diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 5e290c5d02..152ff3b648 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -9,6 +9,7 @@ permissions:
 jobs:
   autofix:
+    if: github.repository == 'langgenius/dify'
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
diff --git a/README.md b/README.md
index 16a1268cb1..775f6f351f 100644
--- a/README.md
+++ b/README.md
@@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/
 One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Deploy to AKS with Azure DevOps Pipeline
+
+One-Click deploy Dify to AKS with [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contributing
diff --git a/README_AR.md b/README_AR.md
index d2cb0098a3..e7a4dbdb27 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -217,6 +217,10 @@ docker compose up -d
 انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### استخدام Azure DevOps Pipeline للنشر على AKS
+
+انشر Dify على AKS بنقرة واحدة باستخدام [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## المساهمة
diff --git a/README_BN.md b/README_BN.md
index f57413ec8b..e4da437eff 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -235,6 +235,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### AKS-এ ডিপ্লয় করার জন্য Azure DevOps Pipeline ব্যবহার
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন
+
 ## Contributing
diff --git a/README_CN.md b/README_CN.md
index e9c73eb48b..82149519d3 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -233,6 +233,9 @@ docker compose up -d
 使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
 
+#### 使用 Azure DevOps Pipeline 部署到 AKS
+
+使用 [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS
 ## Star History
diff --git a/README_DE.md b/README_DE.md
index d31a56542d..2420ac0392 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -230,6 +230,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
 Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Verwendung von Azure DevOps Pipeline für AKS-Bereitstellung
+
+Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden
+
 ## Contributing
diff --git a/README_ES.md b/README_ES.md
index 918bfe2286..4fa59dc18f 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -230,6 +230,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
 Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Uso de Azure DevOps Pipeline para implementar en AKS
+
+Implementa Dify en AKS con un clic usando [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuir
diff --git a/README_FR.md b/README_FR.md
index 56ca878aae..dcbc869620 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -228,6 +228,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
 Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Utilisation d'Azure DevOps Pipeline pour déployer sur AKS
+
+Déployez Dify sur AKS en un clic en utilisant [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuer
diff --git a/README_JA.md b/README_JA.md
index 6d277a36ed..d840fd6419 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -227,6 +227,10 @@ docker compose up -d
 #### Alibaba Cloud Data Management
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
 
+#### AKSへのデプロイにAzure DevOps Pipelineを使用
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ
+
 ## 貢献
diff --git a/README_KL.md b/README_KL.md
index dac67eeb29..41c7969e1c 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -228,6 +228,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### AKS 'e' Deploy je Azure DevOps Pipeline lo'laH
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy
+
 ## Contributing
diff --git a/README_KR.md b/README_KR.md
index 072481da02..d4b31a8928 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -222,6 +222,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
 
+#### AKS에 배포하기 위해 Azure DevOps Pipeline 사용
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포
+
 ## 기여
diff --git a/README_PT.md b/README_PT.md
index 1260f8e6fd..94452cb233 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -227,6 +227,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
 Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Usando Azure DevOps Pipeline para Implantar no AKS
+
+Implante o Dify no AKS com um clique usando [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuindo
diff --git a/README_SI.md b/README_SI.md
index 7ded001d86..d840e9155f 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -228,6 +228,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
 Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Uporaba Azure DevOps Pipeline za uvajanje v AKS
+
+Z enim klikom namestite Dify v AKS z uporabo [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Prispevam
diff --git a/README_TR.md b/README_TR.md
index 37953f0de1..470a7570e0 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -221,6 +221,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
 
+#### AKS'ye Dağıtım için Azure DevOps Pipeline Kullanımı
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın
+
 ## Katkıda Bulunma
diff --git a/README_TW.md b/README_TW.md
index f70d6a25f6..18f1d2754a 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -233,6 +233,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
 透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
 
+#### 使用 Azure DevOps Pipeline 部署到 AKS
+
+使用 [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS
+
 ## 貢獻
diff --git a/README_VI.md b/README_VI.md
index ddd9aa95f6..2ab6da80fc 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -224,6 +224,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
 Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Sử dụng Azure DevOps Pipeline để Triển khai lên AKS
+
+Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure DevOps Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Đóng góp
diff --git a/api/.env.example b/api/.env.example
index 18f2dbf647..4beabfecea 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
 TABLESTORE_INSTANCE_NAME=instance-name
 TABLESTORE_ACCESS_KEY_ID=xxx
 TABLESTORE_ACCESS_KEY_SECRET=xxx
+TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
 
 # Tidb Vector configuration
 TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 68b16e48db..ff290ff99d 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -215,7 +215,7 @@ class DatabaseConfig(BaseSettings):
 class CeleryConfig(DatabaseConfig):
     CELERY_BACKEND: str = Field(
-        description="Backend for Celery task results. Options: 'database', 'redis'.",
+        description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
         default="redis",
     )
@@ -245,7 +245,12 @@ class CeleryConfig(DatabaseConfig):
     @computed_field
     def CELERY_RESULT_BACKEND(self) -> str | None:
-        return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
+        if self.CELERY_BACKEND in ("database", "rabbitmq"):
+            return f"db+{self.SQLALCHEMY_DATABASE_URI}"
+        elif self.CELERY_BACKEND == "redis":
+            return self.CELERY_BROKER_URL
+        else:
+            return None
 
     @property
     def BROKER_USE_SSL(self) -> bool:
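The result-backend resolution above now has three outcomes instead of two: a RabbitMQ broker cannot store task results, so it falls back to the database. A quick sanity check of that logic as a standalone sketch (the function paraphrases the computed property; it is not imported from Dify):

```python
# Minimal sketch of the CELERY_RESULT_BACKEND resolution, assuming only the
# three documented CELERY_BACKEND values ever appear.
def resolve_result_backend(celery_backend: str, db_uri: str, broker_url: str) -> str | None:
    if celery_backend in ("database", "rabbitmq"):
        # RabbitMQ is a broker, not a result store, so results go to the DB.
        return f"db+{db_uri}"
    elif celery_backend == "redis":
        return broker_url
    return None  # unknown backend: no result backend configured


assert resolve_result_backend("rabbitmq", "postgresql://dify", "amqp://guest@mq//") == "db+postgresql://dify"
assert resolve_result_backend("redis", "postgresql://dify", "redis://redis:6379/1") == "redis://redis:6379/1"
```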
diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py
index c4dcc0d465..1aab01c6e1 100644
--- a/api/configs/middleware/vdb/tablestore_config.py
+++ b/api/configs/middleware/vdb/tablestore_config.py
@@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
         description="AccessKey secret for the instance name",
         default=None,
     )
+
+    TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
+        description="Whether to normalize full-text search scores to [0, 1]",
+        default=False,
+    )
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index 2af7136f14..007b1f6d3d 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -100,7 +100,7 @@ class AnnotationReplyActionStatusApi(Resource):
         return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
 
-class AnnotationListApi(Resource):
+class AnnotationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -123,6 +123,23 @@ class AnnotationListApi(Resource):
         }
         return response, 200
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("annotation")
+    @marshal_with(annotation_fields)
+    def post(self, app_id):
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        app_id = str(app_id)
+        parser = reqparse.RequestParser()
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
+        args = parser.parse_args()
+        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
+        return annotation
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -137,7 +154,8 @@ class AnnotationListApi(Resource):
 
         # If annotation_ids are provided, handle batch deletion
         if annotation_ids:
-            if not annotation_ids:
+            # Check if any annotation_ids contain empty strings or invalid values
+            if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id):
                 return {
                     "code": "bad_request",
                     "message": "annotation_ids are required if the parameter is provided.",
@@ -165,25 +183,6 @@ class AnnotationExportApi(Resource):
         return response, 200
 
-class AnnotationCreateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @cloud_edition_billing_resource_check("annotation")
-    @marshal_with(annotation_fields)
-    def post(self, app_id):
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        app_id = str(app_id)
-        parser = reqparse.RequestParser()
-        parser.add_argument("question", required=True, type=str, location="json")
-        parser.add_argument("answer", required=True, type=str, location="json")
-        args = parser.parse_args()
-        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
-        return annotation
-
-
 class AnnotationUpdateDeleteApi(Resource):
     @setup_required
     @login_required
@@ -292,9 +291,8 @@ api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply
 api.add_resource(
     AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/"
 )
-api.add_resource(AnnotationListApi, "/apps//annotations")
+api.add_resource(AnnotationApi, "/apps//annotations")
 api.add_resource(AnnotationExportApi, "/apps//annotations/export")
-api.add_resource(AnnotationCreateApi, "/apps//annotations")
 api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/")
 api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import")
 api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/")
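Folding `AnnotationCreateApi` into the renamed `AnnotationApi` works because flask_restful dispatches each HTTP verb on a route to the same-named method of a single `Resource`; registering two Resources on one URL, as before, is exactly the conflict this removes. A minimal standalone sketch of the pattern (not Dify's actual wiring; handlers and route param are illustrative):

```python
from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class AnnotationApi(Resource):
    def get(self, app_id):
        # GET /apps/<app_id>/annotations -> list annotations
        return {"app_id": app_id, "data": []}

    def post(self, app_id):
        # POST /apps/<app_id>/annotations -> create one annotation
        parser = reqparse.RequestParser()
        parser.add_argument("question", required=True, type=str, location="json")
        parser.add_argument("answer", required=True, type=str, location="json")
        args = parser.parse_args()
        return {"app_id": app_id, **args}, 201


# One registration now serves both the list (GET) and create (POST) verbs.
api.add_resource(AnnotationApi, "/apps/<string:app_id>/annotations")
```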
"≠", diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index a20f2485c8..e7c90c1229 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -4,6 +4,7 @@ import logging import os from datetime import datetime, timedelta from typing import Any, Optional, Union, cast +from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra try: # Choose the appropriate exporter based on config type exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter] + + # Inspect the provided endpoint to determine its structure + parsed = urlparse(arize_phoenix_config.endpoint) + base_endpoint = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + if isinstance(arize_phoenix_config, ArizeConfig): - arize_endpoint = f"{arize_phoenix_config.endpoint}/v1" + arize_endpoint = f"{base_endpoint}/v1" arize_headers = { "api_key": arize_phoenix_config.api_key or "", "space_id": arize_phoenix_config.space_id or "", @@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra timeout=30, ) else: - phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces" + phoenix_endpoint = f"{base_endpoint}{path}/v1/traces" phoenix_headers = { "api_key": arize_phoenix_config.api_key or "", "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 626782cee5..851a77fbc1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com") + return validate_url_with_path(v, "https://app.phoenix.arize.com") class LangfuseConfig(BaseTracingConfig): diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index a607c76beb..7eb5da7e3a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -322,7 +322,7 @@ class OpsTraceManager: :return: """ # auth check - if enabled == True: + if enabled: try: provider_config_map[tracing_provider] except KeyError: @@ -407,7 +407,6 @@ class TraceTask: def __init__( self, trace_type: Any, - trace_id: Optional[str] = None, message_id: Optional[str] = None, workflow_execution: Optional[WorkflowExecution] = None, conversation_id: Optional[str] = None, @@ -423,7 +422,7 @@ class TraceTask: self.timer = timer self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.app_id = None - + self.trace_id = None self.kwargs = kwargs external_trace_id = kwargs.get("external_trace_id") if external_trace_id: diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 784e27fc7f..91d667ff2c 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import Any, Optional import tablestore # type: ignore @@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel): access_key_secret: 
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 75bd2f677a..0df0aa59b2 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[
     "is not",
     "empty",
     "not empty",
+    "in",
+    "not in",
     # for number
     "=",
     "≠",
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index a20f2485c8..e7c90c1229 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -4,6 +4,7 @@ import logging
 import os
 from datetime import datetime, timedelta
 from typing import Any, Optional, Union, cast
+from urllib.parse import urlparse
 
 from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
 from opentelemetry import trace
@@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
     try:
         # Choose the appropriate exporter based on config type
         exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
+
+        # Inspect the provided endpoint to determine its structure
+        parsed = urlparse(arize_phoenix_config.endpoint)
+        base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
+        path = parsed.path.rstrip("/")
+
         if isinstance(arize_phoenix_config, ArizeConfig):
-            arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
+            arize_endpoint = f"{base_endpoint}/v1"
             arize_headers = {
                 "api_key": arize_phoenix_config.api_key or "",
                 "space_id": arize_phoenix_config.space_id or "",
@@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
                 timeout=30,
             )
         else:
-            phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
+            phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
             phoenix_headers = {
                 "api_key": arize_phoenix_config.api_key or "",
                 "authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py
index 626782cee5..851a77fbc1 100644
--- a/api/core/ops/entities/config_entity.py
+++ b/api/core/ops/entities/config_entity.py
@@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig):
     @field_validator("endpoint")
     @classmethod
     def endpoint_validator(cls, v, info: ValidationInfo):
-        return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
+        return validate_url_with_path(v, "https://app.phoenix.arize.com")
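The `urlparse` split is what lets a self-hosted Phoenix endpoint keep a reverse-proxy path prefix while hosted Arize endpoints are reduced to scheme and host before `/v1` is appended. A standalone illustration (hostnames are examples, not real deployments):

```python
from urllib.parse import urlparse


def split_endpoint(endpoint: str) -> tuple[str, str]:
    parsed = urlparse(endpoint)
    # netloc keeps host and port; rstrip drops only a trailing slash.
    return f"{parsed.scheme}://{parsed.netloc}", parsed.path.rstrip("/")


# Arize: any path the user pasted is discarded before /v1 is appended,
# so "https://otlp.arize.com/v1" no longer becomes ".../v1/v1".
base, path = split_endpoint("https://otlp.arize.com/v1")
assert f"{base}/v1" == "https://otlp.arize.com/v1"

# Self-hosted Phoenix behind a path prefix keeps that prefix.
base, path = split_endpoint("https://ml.example.com/phoenix/")
assert f"{base}{path}/v1/traces" == "https://ml.example.com/phoenix/v1/traces"
```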
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index a607c76beb..7eb5da7e3a 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -322,7 +322,7 @@ class OpsTraceManager:
         :return:
         """
         # auth check
-        if enabled == True:
+        if enabled:
             try:
                 provider_config_map[tracing_provider]
             except KeyError:
@@ -407,7 +407,6 @@ class TraceTask:
     def __init__(
         self,
         trace_type: Any,
-        trace_id: Optional[str] = None,
         message_id: Optional[str] = None,
         workflow_execution: Optional[WorkflowExecution] = None,
         conversation_id: Optional[str] = None,
@@ -423,7 +422,7 @@ class TraceTask:
         self.timer = timer
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
         self.app_id = None
-
+        self.trace_id = None
         self.kwargs = kwargs
         external_trace_id = kwargs.get("external_trace_id")
         if external_trace_id:
diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
index 784e27fc7f..91d667ff2c 100644
--- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
+++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import math
 from typing import Any, Optional
 
 import tablestore  # type: ignore
@@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
     access_key_secret: Optional[str] = None
     instance_name: Optional[str] = None
     endpoint: Optional[str] = None
+    normalize_full_text_bm25_score: Optional[bool] = False
 
     @model_validator(mode="before")
     @classmethod
@@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
             config.access_key_secret,
             config.instance_name,
         )
+        self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
         self._table_name = f"{collection_name}"
         self._index_name = f"{collection_name}_idx"
         self._tags_field = f"{Field.METADATA_KEY.value}_tags"
@@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
         filtered_list = None
         if document_ids_filter:
             filtered_list = ["document_id=" + item for item in document_ids_filter]
-
-        return self._search_by_full_text(query, filtered_list, top_k)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
 
     def delete(self) -> None:
         self._delete_table_if_exist()
@@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
         documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents
 
-    def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
+    @staticmethod
+    def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
+        """
+        Args:
+            score: BM25 search score.
+            k: decay factor; the larger k is, the steeper the curve at the low-score end.
+        """
+        normalized_score = 1 - math.exp(-k * score)
+        return max(0.0, min(1.0, normalized_score))
+
+    def _search_by_full_text(
+        self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
+    ) -> list[Document]:
         bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
         bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
@@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
         documents = []
         for search_hit in search_response.search_hits:
+            score = None
+            if self._normalize_full_text_bm25_score:
+                score = self._normalize_score_exp_decay(search_hit.score)
+
+            # Skip the hit when normalization is enabled and its normalized
+            # score falls at or below the threshold
+            if score and score <= score_threshold:
+                continue
+
             ots_column_map = {}
             for col in search_hit.row[1]:
                 ots_column_map[col[0]] = col[1]
-            vector_str = ots_column_map.get(Field.VECTOR.value)
             metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
-            vector = json.loads(vector_str) if vector_str else None
             metadata = json.loads(metadata_str) if metadata_str else {}
+            vector_str = ots_column_map.get(Field.VECTOR.value)
+            vector = json.loads(vector_str) if vector_str else None
+
+            if score:
+                metadata["score"] = score
+
             documents.append(
                 Document(
                     page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
@@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
                     metadata=metadata,
                 )
             )
+        if self._normalize_full_text_bm25_score:
+            documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
 
         return documents
 
@@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
             instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
             access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
             access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
+            normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
         ),
     )
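The exp-decay mapping keeps BM25's ordering while squashing its unbounded scores into [0, 1], so a single `score_threshold` can be applied uniformly across retrieval backends. A standalone sketch of the same formula:

```python
import math


def normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
    # Monotonic map from [0, inf) to [0, 1): higher BM25 scores approach 1.0
    # and the relative ordering of hits is preserved.
    normalized = 1 - math.exp(-k * score)
    return max(0.0, min(1.0, normalized))


for raw in (0.0, 2.0, 10.0, 30.0):
    print(raw, round(normalize_score_exp_decay(raw), 3))
# 0.0 -> 0.0, 2.0 -> 0.259, 10.0 -> 0.777, 30.0 -> 0.989
```

One behavioral detail worth noting: because the guard above is `if score and score <= score_threshold`, a hit whose normalized score is exactly 0.0 is never filtered (0.0 is falsy), and no filtering happens at all when normalization is disabled.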
"multipart/form-data" + elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: + # Set Content-Type for other body types + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + return headers def _validate_and_parse_response(self, response: httpx.Response) -> Response: @@ -384,15 +400,24 @@ class Executor: # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. # This prevents logging meaningless placeholder entries. if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for key, (filename, content, mime_type) in self.files: + for file_entry in self.files: + # file_entry should be (key, (filename, content, mime_type)), but handle edge cases + if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2: + continue # skip malformed entries + key = file_entry[0] + content = file_entry[1][1] body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - # fix: decode binary content - pass + # decode content safely + if isinstance(content, bytes): + try: + body_string += content.decode("utf-8") + except UnicodeDecodeError: + body_string += content.decode("utf-8", errors="replace") + elif isinstance(content, str): + body_string += content + else: + body_string += f"[Unsupported content type: {type(content).__name__}]" body_string += "\r\n" body_string += f"--{boundary}--\r\n" elif self.node_data.body: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index f1767bdf9e..b71271abeb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e041e217ca..7303b68501 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode): **{key: metadata_name, key_value: f"%{value}"} ) ) + case "in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) + case "not in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) case "=" | "is": if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 90a0397b67..dfc2a0000b 100644 --- 
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index f1767bdf9e..b71271abeb 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[
     "is not",
     "empty",
     "not empty",
+    "in",
+    "not in",
     # for number
     "=",
     "≠",
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index e041e217ca..7303b68501 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode):
                             **{key: metadata_name, key_value: f"%{value}"}
                         )
                     )
+                case "in":
+                    if isinstance(value, str):
+                        escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
+                        escaped_value_str = ",".join(escaped_values)
+                    else:
+                        escaped_value_str = str(value)
+                    filters.append(
+                        (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
+                            **{key: metadata_name, key_value: escaped_value_str}
+                        )
+                    )
+                case "not in":
+                    if isinstance(value, str):
+                        escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
+                        escaped_value_str = ",".join(escaped_values)
+                    else:
+                        escaped_value_str = str(value)
+                    filters.append(
+                        (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
+                            **{key: metadata_name, key_value: escaped_value_str}
+                        )
+                    )
                 case "=" | "is":
                     if isinstance(value, str):
                         filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
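On the PostgreSQL side, the new operators reduce to `= ANY(...)` and `!= ALL(...)` over a comma-split array. A hedged sketch of the same query shape via SQLAlchemy (the DSN is a placeholder, the column and table names mirror the filter above, and the dynamic per-filter parameter names are simplified to fixed ones):

```python
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://localhost/dify")  # placeholder DSN

# '= ANY' keeps rows whose metadata value is one of the listed items;
# '!= ALL' keeps rows that match none of them.
stmt = text(
    "SELECT id, doc_metadata ->> :field AS category "
    "FROM documents "
    "WHERE doc_metadata ->> :field = ANY(string_to_array(:values, ','))"
)
with engine.connect() as conn:
    rows = conn.execute(stmt, {"field": "category", "values": "legal,finance"}).all()
```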
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 90a0397b67..dfc2a0000b 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -3,7 +3,7 @@ import io
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import FileType, file_manager
@@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
-    AIModelEntity,
     ModelFeature,
     ModelPropertyKey,
     ModelType,
 )
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -1006,21 +1004,6 @@ class LLMNode(BaseNode):
         )
         return saved_file
 
-    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
-        """
-        Fetch model schema
-        """
-        model_name = self._node_data.model.name
-        model_manager = ModelManager()
-        model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
-        )
-        model_type_instance = model_instance.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        model_credentials = model_instance.credentials
-        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
-        return model_schema
-
     @staticmethod
     def fetch_structured_output_schema(
         *,
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 512a9cb608..b2bcee5dcd 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,4 +1,6 @@
 import mimetypes
+import os
+import urllib.parse
 import uuid
 from collections.abc import Callable, Mapping, Sequence
 from typing import Any, cast
@@ -240,16 +242,21 @@ def _build_from_remote_url(
 
 def _get_remote_file_info(url: str):
     file_size = -1
-    filename = url.split("/")[-1].split("?")[0] or "unknown_file"
-    mime_type = mimetypes.guess_type(filename)[0] or ""
+    parsed_url = urllib.parse.urlparse(url)
+    url_path = parsed_url.path
+    filename = os.path.basename(url_path)
+
+    # Initialize mime_type from filename as fallback
+    mime_type, _ = mimetypes.guess_type(filename)
 
     resp = ssrf_proxy.head(url, follow_redirects=True)
     resp = cast(httpx.Response, resp)
     if resp.status_code == httpx.codes.OK:
         if content_disposition := resp.headers.get("Content-Disposition"):
             filename = str(content_disposition.split("filename=")[-1].strip('"'))
+            # Re-guess mime_type from updated filename
+            mime_type, _ = mimetypes.guess_type(filename)
         file_size = int(resp.headers.get("Content-Length", file_size))
-        mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
 
     return mime_type, filename, file_size
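The switch from naive string splitting to `urlparse` plus `os.path.basename` matters for URLs carrying fragments or nested query strings. A standalone sketch of just the filename/MIME part of the change:

```python
import mimetypes
import os
import urllib.parse


def guess_remote_filename(url: str) -> tuple[str, str | None]:
    # urlparse strips query and fragment before the basename is taken,
    # unlike the old url.split("/")[-1].split("?")[0] approach.
    path = urllib.parse.urlparse(url).path
    filename = os.path.basename(path)
    mime_type, _ = mimetypes.guess_type(filename)
    return filename, mime_type


assert guess_remote_filename("https://cdn.example.com/a/report.pdf?sig=abc#p2") == (
    "report.pdf",
    "application/pdf",
)
# Fragments previously survived the naive split ("img.png#frag"); urlparse drops them.
assert guess_remote_filename("https://cdn.example.com/a/img.png#frag")[0] == "img.png"
```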
diff --git a/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py
new file mode 100644
index 0000000000..1664fb99c4
--- /dev/null
+++ b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py
@@ -0,0 +1,25 @@
+"""manual dataset field update
+
+Revision ID: 532b3f888abf
+Revises: 8bcc02c9bd07
+Create Date: 2025-07-24 14:50:48.779833
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '532b3f888abf'
+down_revision = '8bcc02c9bd07'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
+
+
+def downgrade():
+    op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")
diff --git a/api/models/account.py b/api/models/account.py
index d63c5d7fb5..3437055893 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -4,7 +4,7 @@ from datetime import datetime
 from typing import Optional, cast
 
 from flask_login import UserMixin  # type: ignore
-from sqlalchemy import func, select
+from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, mapped_column, reconstructor
 
 from models.base import Base
@@ -86,23 +86,21 @@ class Account(UserMixin, Base):
     __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    name: Mapped[str] = mapped_column(db.String(255))
-    email: Mapped[str] = mapped_column(db.String(255))
-    password: Mapped[Optional[str]] = mapped_column(db.String(255))
-    password_salt: Mapped[Optional[str]] = mapped_column(db.String(255))
-    avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
-    interface_language: Mapped[Optional[str]] = mapped_column(db.String(255))
-    interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
-    timezone: Mapped[Optional[str]] = mapped_column(db.String(255))
-    last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
-    last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
-    last_active_at: Mapped[datetime] = mapped_column(
-        db.DateTime, server_default=func.current_timestamp(), nullable=False
-    )
-    status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying"))
-    initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
-    updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
+    name: Mapped[str] = mapped_column(String(255))
+    email: Mapped[str] = mapped_column(String(255))
+    password: Mapped[Optional[str]] = mapped_column(String(255))
+    password_salt: Mapped[Optional[str]] = mapped_column(String(255))
+    avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    interface_language: Mapped[Optional[str]] = mapped_column(String(255))
+    interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    timezone: Mapped[Optional[str]] = mapped_column(String(255))
+    last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+    last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
+    status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying"))
+    initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
+    updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
 
     @reconstructor
     def init_on_load(self):
@@ -200,13 +198,13 @@ class Tenant(Base):
     __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    name: Mapped[str] = mapped_column(db.String(255))
+    name: Mapped[str] = mapped_column(String(255))
     encrypt_public_key = db.Column(db.Text)
-    plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying"))
-    status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
+    plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying"))
+    status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying"))
     custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
-    updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
+    updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
 
     def get_accounts(self) -> list[Account]:
         return (
@@ -237,10 +235,10 @@ class TenantAccountJoin(Base):
     tenant_id: Mapped[str] = mapped_column(StringUUID)
     account_id: Mapped[str] = mapped_column(StringUUID)
     current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
-    role: Mapped[str] = mapped_column(db.String(16), server_default="normal")
+    role: Mapped[str] = mapped_column(String(16), server_default="normal")
     invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
-    updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
 
 
 class AccountIntegrate(Base):
@@ -253,11 +251,11 @@ class AccountIntegrate(Base):
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     account_id: Mapped[str] = mapped_column(StringUUID)
-    provider: Mapped[str] = mapped_column(db.String(16))
-    open_id: Mapped[str] = mapped_column(db.String(255))
-    encrypted_token: Mapped[str] = mapped_column(db.String(255))
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
-    updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+    provider: Mapped[str] = mapped_column(String(16))
+    open_id: Mapped[str] = mapped_column(String(255))
+    encrypted_token: Mapped[str] = mapped_column(String(255))
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
+    updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
 
 
 class InvitationCode(Base):
@@ -269,14 +267,14 @@ class InvitationCode(Base):
     )
 
     id: Mapped[int] = mapped_column(db.Integer)
-    batch: Mapped[str] = mapped_column(db.String(255))
-    code: Mapped[str] = mapped_column(db.String(32))
-    status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying"))
-    used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    batch: Mapped[str] = mapped_column(String(255))
+    code: Mapped[str] = mapped_column(String(32))
+    status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying"))
+    used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
     used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
     used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
-    deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
 class TenantPluginPermission(Base):
@@ -298,10 +296,8 @@ class TenantPluginPermission(Base):
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    install_permission: Mapped[InstallPermission] = mapped_column(
-        db.String(16), nullable=False, server_default="everyone"
-    )
-    debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")
+    install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
+    debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
 
 
 class TenantPluginAutoUpgradeStrategy(Base):
@@ -323,14 +319,10 @@ class TenantPluginAutoUpgradeStrategy(Base):
 
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only")
+    strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
     upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)  # seconds of the day
-    upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude")
-    exclude_plugins: Mapped[list[str]] = mapped_column(
-        db.ARRAY(db.String(255)), nullable=False
-    )  # plugin_id (author/name)
-    include_plugins: Mapped[list[str]] = mapped_column(
-        db.ARRAY(db.String(255)), nullable=False
-    )  # plugin_id (author/name)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
+    exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
+    include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
+    created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py
index 3cef5a0fb2..ac9eda6829 100644
--- a/api/models/api_based_extension.py
+++ b/api/models/api_based_extension.py
@@ -1,7 +1,8 @@
 import enum
+from datetime import datetime
 
-from sqlalchemy import func
-from sqlalchemy.orm import mapped_column
+from sqlalchemy import DateTime, String, Text, func
+from sqlalchemy.orm import Mapped, mapped_column
 
 from .base import Base
 from .engine import db
@@ -24,7 +25,7 @@ class APIBasedExtension(Base):
 
     id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = mapped_column(StringUUID, nullable=False)
-    name = mapped_column(db.String(255), nullable=False)
-    api_endpoint = mapped_column(db.String(255), nullable=False)
-    api_key = mapped_column(db.Text, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    name: Mapped[str] = mapped_column(String(255), nullable=False)
+    api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
+    api_key = mapped_column(Text, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 01372f8bf6..e62101ae73 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -12,7 +12,7 @@ from datetime import datetime
 from json import JSONDecodeError
 from typing import Any, Optional, cast
 
-from sqlalchemy import func, select
+from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import Mapped, mapped_column
 
@@ -48,22 +48,22 @@ class Dataset(Base):
 
     id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id: Mapped[str] = mapped_column(StringUUID)
-    name: Mapped[str] = mapped_column(db.String(255))
+    name: Mapped[str] = mapped_column(String(255))
     description = mapped_column(db.Text, nullable=True)
-    provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying"))
-    permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying"))
-    data_source_type = mapped_column(db.String(255))
-    indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255))
+    provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying"))
+    permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying"))
+    data_source_type = mapped_column(String(255))
+    indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
     index_struct = mapped_column(db.Text, nullable=True)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = mapped_column(StringUUID, nullable=True)
-    updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    embedding_model = db.Column(db.String(255), nullable=True)  # TODO: mapped_column
-    embedding_model_provider = db.Column(db.String(255), nullable=True)  # TODO: mapped_column
+    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+    embedding_model = db.Column(String(255), nullable=True)  # TODO: mapped_column
+    embedding_model_provider = db.Column(String(255), nullable=True)  # TODO: mapped_column
     collection_binding_id = mapped_column(StringUUID, nullable=True)
     retrieval_model = mapped_column(JSONB, nullable=True)
-    built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+    built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
 
     @property
     def dataset_keyword_table(self):
@@ -268,10 +268,10 @@ class DatasetProcessRule(Base):
 
     id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = mapped_column(StringUUID, nullable=False)
-    mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
+    mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
     rules = mapped_column(db.Text, nullable=True)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
 
     MODES = ["automatic", "custom", "hierarchical"]
     PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
@@ -313,61 +313,59 @@ class Document(Base):
     id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
     tenant_id = mapped_column(StringUUID, nullable=False)
     dataset_id = mapped_column(StringUUID, nullable=False)
-    position = mapped_column(db.Integer, nullable=False)
-    data_source_type = mapped_column(db.String(255), nullable=False)
+    position: Mapped[int] = mapped_column(db.Integer, nullable=False)
+    data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
     data_source_info = mapped_column(db.Text, nullable=True)
     dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
-    batch = mapped_column(db.String(255), nullable=False)
-    name = mapped_column(db.String(255), nullable=False)
-    created_from = mapped_column(db.String(255), nullable=False)
+    batch: Mapped[str] = mapped_column(String(255), nullable=False)
+    name: Mapped[str] = mapped_column(String(255), nullable=False)
+    created_from: Mapped[str] = mapped_column(String(255), nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
     created_api_request_id = mapped_column(StringUUID, nullable=True)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
 
     # start processing
-    processing_started_at = mapped_column(db.DateTime, nullable=True)
+    processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # parsing
     file_id = mapped_column(db.Text, nullable=True)
-    word_count = mapped_column(db.Integer, nullable=True)
-    parsing_completed_at = mapped_column(db.DateTime, nullable=True)
+    word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)  # TODO: make this not nullable
+    parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # cleaning
-    cleaning_completed_at = mapped_column(db.DateTime, nullable=True)
+    cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # split
-    splitting_completed_at = mapped_column(db.DateTime, nullable=True)
+    splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # indexing
-    tokens = mapped_column(db.Integer, nullable=True)
-    indexing_latency = mapped_column(db.Float, nullable=True)
-    completed_at = mapped_column(db.DateTime, nullable=True)
+    tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
+    indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
+    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # pause
-    is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
+    is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
     paused_by = mapped_column(StringUUID, nullable=True)
-    paused_at = mapped_column(db.DateTime, nullable=True)
+    paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # error
     error = mapped_column(db.Text, nullable=True)
-    stopped_at = mapped_column(db.DateTime, nullable=True)
+    stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     # basic fields
-    indexing_status = mapped_column(
-        db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")
-    )
-    enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
-    disabled_at = mapped_column(db.DateTime, nullable=True)
+    indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
+    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+    disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
     disabled_by = mapped_column(StringUUID, nullable=True)
-    archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
-    archived_reason = mapped_column(db.String(255), nullable=True)
+    archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+    archived_reason = mapped_column(String(255), nullable=True)
     archived_by = mapped_column(StringUUID, nullable=True)
-    archived_at = mapped_column(db.DateTime, nullable=True)
-    updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    doc_type = mapped_column(db.String(40), nullable=True)
+    archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+    doc_type = mapped_column(String(40), nullable=True)
     doc_metadata = mapped_column(JSONB, nullable=True)
-    doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
-    doc_language = mapped_column(db.String(255), nullable=True)
+    doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
+    doc_language = mapped_column(String(255), nullable=True)
 
     DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
 
@@ -524,7 +522,7 @@ class Document(Base):
                     "id": "built-in",
                     "name": BuiltInField.upload_date,
                     "type": "time",
-                    "value": self.created_at.timestamp(),
+                    "value": str(self.created_at.timestamp()),
                 }
             )
             built_in_fields.append(
@@ -532,7 +530,7 @@ class Document(Base):
                    "id": "built-in",
                     "name": BuiltInField.last_update_date,
                     "type": "time",
-                    "value": self.updated_at.timestamp(),
+                    "value": str(self.updated_at.timestamp()),
                 }
             )
             built_in_fields.append(
@@ -667,23 +665,23 @@ class DocumentSegment(Base):
 
     # indexing fields
     keywords = mapped_column(db.JSON, nullable=True)
-    index_node_id = mapped_column(db.String(255), nullable=True)
-    index_node_hash = mapped_column(db.String(255), nullable=True)
+    index_node_id = mapped_column(String(255), nullable=True)
+    index_node_hash = mapped_column(String(255), nullable=True)
 
     # basic fields
-    hit_count = mapped_column(db.Integer, nullable=False, default=0)
-    enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
-    disabled_at = mapped_column(db.DateTime, nullable=True)
+    hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
+    enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+    disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
     disabled_by = mapped_column(StringUUID, nullable=True)
-    status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying"))
+    status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying"))
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = mapped_column(StringUUID, nullable=True)
-    updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    indexing_at = mapped_column(db.DateTime, nullable=True)
-    completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
+    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+    indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
     error = mapped_column(db.Text, nullable=True)
-    stopped_at = mapped_column(db.DateTime, nullable=True)
+    stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
 
     @property
     def dataset(self):
@@ -808,19 +806,23 @@ class ChildChunk(Base):
     dataset_id = mapped_column(StringUUID, nullable=False)
     document_id = mapped_column(StringUUID, nullable=False)
     segment_id = mapped_column(StringUUID, nullable=False)
-    position = mapped_column(db.Integer, nullable=False)
+    position: Mapped[int] = mapped_column(db.Integer, nullable=False)
     content = mapped_column(db.Text, nullable=False)
-    word_count = mapped_column(db.Integer, nullable=False)
+    word_count: Mapped[int] = mapped_column(db.Integer, nullable=False)
     # indexing fields
-    index_node_id = mapped_column(db.String(255), nullable=True)
-    index_node_hash = mapped_column(db.String(255), nullable=True)
-    type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
+    index_node_id = mapped_column(String(255), nullable=True)
+    index_node_hash = mapped_column(String(255), nullable=True)
+    type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
+    )
     updated_by = mapped_column(StringUUID, nullable=True)
-    updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
-    indexing_at = mapped_column(db.DateTime, nullable=True)
-    completed_at = mapped_column(db.DateTime, nullable=True)
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
+    )
+    indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
+    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
     error = mapped_column(db.Text, nullable=True)
 
     @property
@@ -846,7 +848,7 @@ class AppDatasetJoin(Base):
     id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
     app_id = mapped_column(StringUUID, nullable=False)
     dataset_id = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
 
     @property
     def app(self):
@@ -863,11 +865,11 @@ class DatasetQuery(Base):
     id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
     dataset_id = mapped_column(StringUUID, nullable=False)
     content = mapped_column(db.Text, nullable=False)
-    source = mapped_column(db.String(255), nullable=False)
+    source: Mapped[str] = mapped_column(String(255), nullable=False)
     source_app_id = mapped_column(StringUUID, nullable=True)
-    created_by_role = mapped_column(db.String, nullable=False)
+    created_by_role = mapped_column(String, nullable=False)
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
 class DatasetKeywordTable(Base):
@@ -881,7 +883,7 @@ class DatasetKeywordTable(Base):
     dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
     keyword_table = mapped_column(db.Text, nullable=False)
     data_source_type = mapped_column(
-        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
+        String(255), nullable=False, server_default=db.text("'database'::character varying")
     )
 
     @property
@@ -925,12 +927,12 @@ class Embedding(Base):
 
     id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
     model_name = mapped_column(
-        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
+        String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
     )
-    hash = mapped_column(db.String(64), nullable=False)
+    hash = mapped_column(String(64), nullable=False)
     embedding = mapped_column(db.LargeBinary, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+    provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying"))
 
     def set_embedding(self, embedding_data: list[float]):
         self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
@@ -947,11 +949,11 @@ class DatasetCollectionBinding(Base):
     )
 
     id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
-    provider_name = mapped_column(db.String(255), nullable=False)
-    model_name = mapped_column(db.String(255), nullable=False)
-    type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
-
collection_name = mapped_column(db.String(64), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): @@ -965,13 +967,13 @@ class TidbAuthBinding(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - cluster_id = mapped_column(db.String(255), nullable=False) - cluster_name = mapped_column(db.String(255), nullable=False) - active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = mapped_column(db.String(255), nullable=False) - password = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) + cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) + active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying")) + account: Mapped[str] = mapped_column(String(255), nullable=False) + password: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): @@ -982,8 +984,8 @@ class Whitelist(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - category = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): @@ -999,8 +1001,8 @@ class DatasetPermission(Base): dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): @@ -1012,14 +1014,14 @@ class ExternalKnowledgeApis(Base): ) id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: 
Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) settings = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1072,9 +1074,9 @@ class ExternalKnowledgeBindings(Base): dataset_id = mapped_column(StringUUID, nullable=False) external_knowledge_id = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): @@ -1090,8 +1092,10 @@ class DatasetAutoDisableLog(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class RateLimitLog(Base): @@ -1104,9 +1108,11 @@ class RateLimitLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - subscription_plan = mapped_column(db.String(255), nullable=False) - operation = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) + operation: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class DatasetMetadata(Base): @@ -1120,10 +1126,14 @@ class DatasetMetadata(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + 
created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1143,5 +1153,5 @@ class DatasetMetadataBinding(Base): dataset_id = mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 9f6d51b315..fba0d692eb 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -37,7 +37,7 @@ class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -73,15 +73,15 @@ class App(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji - icon = db.Column(db.String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + mode: Mapped[str] = mapped_column(String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon = db.Column(String(255)) + icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) enable_site: Mapped[bool] = mapped_column(db.Boolean) enable_api: Mapped[bool] = mapped_column(db.Boolean) api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) @@ -306,8 +306,8 @@ class AppModelConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) configs = mapped_column(db.JSON, nullable=True) 
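The api/models/dataset.py hunks above all apply a single conversion: legacy flask-sqlalchemy declarations (db.String, db.DateTime, bare mapped_column(...) assignments) become SQLAlchemy 2.0 typed mappings, with a Mapped[...] annotation added and Optional[...] mirroring nullable=True. Apart from deliberate value fixes (the quoted 'CREATING' server default on TidbAuthBinding, the stringified timestamps in the built-in metadata fields), the conversion is annotation-only and leaves the generated DDL unchanged. A minimal sketch of the convention on a hypothetical standalone model, not taken from this patch:

    # Illustrative only: the typed-mapping style this patch converges on.
    from datetime import datetime
    from typing import Optional

    from sqlalchemy import DateTime, String, func, text
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class ExampleRecord(Base):
        __tablename__ = "example_records"

        id: Mapped[int] = mapped_column(primary_key=True)
        # A non-Optional annotation documents NOT NULL to type checkers;
        # String(255) is the plain SQLAlchemy type, not the db.String proxy.
        name: Mapped[str] = mapped_column(String(255), nullable=False)
        # Nullable columns carry Optional[...] so mypy sees `str | None`.
        note: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
        # Server-side defaults are carried over verbatim.
        status: Mapped[str] = mapped_column(
            String(255), server_default=text("'waiting'::character varying")
        )
        created_at: Mapped[datetime] = mapped_column(
            DateTime, nullable=False, server_default=func.current_timestamp()
        )

The Mapped[Optional[...]] spelling is what lets a type checker distinguish nullable columns from required ones at attribute access, which is the point of the migration.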
created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -321,12 +321,12 @@ class AppModelConfig(Base): more_like_this = mapped_column(db.Text) model = mapped_column(db.Text) user_input_form = mapped_column(db.Text) - dataset_query_variable = mapped_column(db.String(255)) + dataset_query_variable = mapped_column(String(255)) pre_prompt = mapped_column(db.Text) agent_mode = mapped_column(db.Text) sensitive_word_avoidance = mapped_column(db.Text) retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = mapped_column(db.Text) completion_prompt_config = mapped_column(db.Text) dataset_configs = mapped_column(db.Text) @@ -561,14 +561,14 @@ class RecommendedApp(Base): id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) description = mapped_column(db.JSON, nullable=False) - copyright = mapped_column(db.String(255), nullable=False) - privacy_policy = mapped_column(db.String(255), nullable=False) + copyright: Mapped[str] = mapped_column(String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = mapped_column(db.String(255), nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_listed = mapped_column(db.Boolean, nullable=False, default=True) - install_count = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + category: Mapped[str] = mapped_column(String(255), nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying")) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -591,8 +591,8 @@ class InstalledApp(Base): tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -617,26 +617,26 @@ class Conversation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) - model_provider = mapped_column(db.String(255), 
nullable=True) + model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) - model_id = mapped_column(db.String(255), nullable=True) - mode: Mapped[str] = mapped_column(db.String(255)) - name = mapped_column(db.String(255), nullable=False) + model_id = mapped_column(String(255), nullable=True) + mode: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = mapped_column(db.Text) system_instruction = mapped_column(db.Text) - system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - status = mapped_column(db.String(255), nullable=False) + system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(db.String(255), nullable=True) + invoke_from = mapped_column(String(255), nullable=True) # ref: ConversationSource. - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(db.DateTime) @@ -650,7 +650,7 @@ class Conversation(Base): "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): @@ -894,8 +894,8 @@ class Message(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - model_provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + model_provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -911,17 +911,17 @@ class Message(Base): parent_message_id = mapped_column(StringUUID, nullable=True) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = mapped_column(db.Text) message_metadata = mapped_column(db.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - from_source = mapped_column(db.String(255), nullable=False) + invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) 
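In the Conversation and Message hunks above, some columns gain annotations while others stay bare (Conversation keeps invoke_from un-annotated, while Message annotates it Mapped[Optional[str]]). The SQLAlchemy 2.0 rules the patch relies on: an explicit nullable= always wins, Optional[...] alone implies a nullable column, and un-annotated attributes still map but read as Any to type checkers. A self-contained sketch on a hypothetical Demo model:

    from typing import Optional

    from sqlalchemy import String
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class Demo(Base):
        __tablename__ = "demo"

        id: Mapped[int] = mapped_column(primary_key=True)
        # No nullable= given: NOT NULL is inferred from the non-Optional annotation.
        a: Mapped[str] = mapped_column(String(255))
        # Optional[...] alone infers a nullable column.
        b: Mapped[Optional[str]] = mapped_column(String(255))
        # An explicit nullable= overrides whatever the annotation implies.
        c: Mapped[str] = mapped_column(String(255), nullable=True)
        # Un-annotated attributes still work but type as Any.
        d = mapped_column(String(255), nullable=True)

    assert Demo.__table__.c.a.nullable is False
    assert Demo.__table__.c.b.nullable is True
    assert Demo.__table__.c.c.nullable is True

The c and d cases are why pre-existing mismatches such as UploadFile.mime_type (Mapped[str] with nullable=True, kept as-is later in this file) still compile: the explicit flag governs the schema, and the annotation only informs checkers.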
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1238,9 +1238,9 @@ class MessageFeedback(Base): app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) - rating = mapped_column(db.String(255), nullable=False) + rating: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.Text) - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1298,12 +1298,12 @@ class MessageFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1323,7 +1323,7 @@ class MessageAnnotation(Base): message_id: Mapped[Optional[str]] = mapped_column(StringUUID) question = db.Column(db.Text, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) account_id = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1415,10 +1415,10 @@ class OperationLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - action = mapped_column(db.String(255), nullable=False) + action: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.JSON) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = mapped_column(db.String(255), 
nullable=False) + created_ip: Mapped[str] = mapped_column(String(255), nullable=False) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1433,10 +1433,10 @@ class EndUser(Base, UserMixin): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(255), nullable=False) - external_user_id = mapped_column(db.String(255), nullable=True) - name = mapped_column(db.String(255)) - is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + external_user_id = mapped_column(String(255), nullable=True) + name = mapped_column(String(255)) + is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) session_id: Mapped[str] = mapped_column() created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1452,10 +1452,10 @@ class AppMCPServer(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) - server_code = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + server_code: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) parameters = mapped_column(db.Text, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1485,28 +1485,28 @@ class Site(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - title = mapped_column(db.String(255), nullable=False) - icon_type = mapped_column(db.String(255), nullable=True) - icon = mapped_column(db.String(255)) - icon_background = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(String(255), nullable=False) + icon_type = mapped_column(String(255), nullable=True) + icon = mapped_column(String(255)) + icon_background = mapped_column(String(255)) description = mapped_column(db.Text) - default_language = mapped_column(db.String(255), nullable=False) - chat_color_theme = mapped_column(db.String(255)) - chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = mapped_column(db.String(255)) - privacy_policy = mapped_column(db.String(255)) - show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + default_language: Mapped[str] = mapped_column(String(255), nullable=False) + chat_color_theme = mapped_column(String(255)) + chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, 
server_default=db.text("false")) + copyright = mapped_column(String(255)) + privacy_policy = mapped_column(String(255)) + show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = mapped_column(db.String(255)) - customize_token_strategy = mapped_column(db.String(255), nullable=False) - prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + customize_domain = mapped_column(String(255)) + customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = mapped_column(db.String(255)) + code = mapped_column(String(255)) @property def custom_disclaimer(self): @@ -1544,8 +1544,8 @@ class ApiToken(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - token = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1567,21 +1567,21 @@ class UploadFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) - key: Mapped[str] = mapped_column(db.String(255), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + key: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) size: Mapped[int] = mapped_column(db.Integer, nullable=False) - extension: Mapped[str] = mapped_column(db.String(255), nullable=False) - mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + extension: Mapped[str] = mapped_column(String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = 
mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1630,10 +1630,10 @@ class ApiRequest(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) - path = mapped_column(db.String(255), nullable=False) + path: Mapped[str] = mapped_column(String(255), nullable=False) request = mapped_column(db.Text, nullable=True) response = mapped_column(db.Text, nullable=True) - ip = mapped_column(db.String(255), nullable=False) + ip: Mapped[str] = mapped_column(String(255), nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1646,7 +1646,7 @@ class MessageChain(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) input = mapped_column(db.Text, nullable=True) output = mapped_column(db.Text, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1663,7 +1663,7 @@ class MessageAgentThought(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = mapped_column(StringUUID, nullable=True) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) thought = mapped_column(db.Text, nullable=True) tool = mapped_column(db.Text, nullable=True) tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) @@ -1673,19 +1673,19 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(db.Text, nullable=True) message = mapped_column(db.Text, nullable=True) - message_token = mapped_column(db.Integer, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) message_unit_price = mapped_column(db.Numeric, nullable=True) message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = mapped_column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) - answer_token = mapped_column(db.Integer, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) answer_unit_price = mapped_column(db.Numeric, nullable=True) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = mapped_column(db.Integer, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) total_price = mapped_column(db.Numeric, nullable=True) - currency = mapped_column(db.String, nullable=True) - latency = mapped_column(db.Float, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + currency = mapped_column(String, nullable=True) + 
latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1775,18 +1775,18 @@ class DatasetRetrieverResource(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) dataset_name = mapped_column(db.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) document_name = mapped_column(db.Text, nullable=False) data_source_type = mapped_column(db.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score = mapped_column(db.Float, nullable=True) + score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - segment_position = mapped_column(db.Integer, nullable=True) + hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) index_node_hash = mapped_column(db.Text, nullable=True) retriever_from = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) @@ -1805,8 +1805,8 @@ class Tag(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - name = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1836,13 +1836,13 @@ class TraceAppConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(db.String(255), nullable=True) + tracing_provider = mapped_column(String(255), nullable=True) tracing_config = mapped_column(db.JSON, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90f..7bfc249b0b 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import func, text +from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -56,22 +56,22 @@ class Provider(Base): id: Mapped[str] = 
mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_type: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'custom'::character varying") + String(40), nullable=False, server_default=text("'custom'::character varying") ) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( - db.String(40), nullable=True, server_default=text("''::character varying") + String(40), nullable=True, server_default=text("''::character varying") ) quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -113,13 +113,13 @@ class ProviderModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -131,11 +131,11 @@ class TenantDefaultModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 
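The api/models/provider.py hunks swap db.DateTime for the directly imported DateTime on every timestamp column while keeping server_default=func.current_timestamp(), so the database, not Python, stamps these columns, and the rename does not alter the emitted DDL. One way to confirm what such a column compiles to (hypothetical one-column table; output abbreviated):

    from sqlalchemy import Column, DateTime, MetaData, Table, func
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.schema import CreateTable

    t = Table(
        "example",
        MetaData(),
        Column(
            "created_at",
            DateTime,
            nullable=False,
            server_default=func.current_timestamp(),
        ),
    )
    print(CreateTable(t).compile(dialect=postgresql.dialect()))
    # Prints roughly:
    # CREATE TABLE example (
    #     created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
    # )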
- updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -147,10 +147,10 @@ class TenantPreferredModelProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -162,22 +162,22 @@ class ProviderOrder(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + currency: Mapped[Optional[str]] = mapped_column(String(40)) total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) payment_status: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + pay_failed_at: Mapped[Optional[datetime]] = 
mapped_column(DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -193,13 +193,13 @@ class ProviderModelSetting(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -215,11 +215,11 @@ class LoadBalancingModelConfig(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 100e0d96ef..8191c874a4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,8 +1,10 @@ import json +from datetime import datetime +from typing import Optional -from sqlalchemy import func +from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import mapped_column +from 
sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -20,12 +22,12 @@ class DataSourceOauthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - access_token = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + access_token: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -38,12 +40,12 @@ class DataSourceApiKeyAuthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - category = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + category: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) credentials = mapped_column(db.Text, nullable=True) # JSON - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 3e5ebd2099..66a47ea4df 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Optional from celery import states # type: ignore +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now @@ -16,22 +17,22 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = mapped_column(db.String(155), unique=True) - status = mapped_column(db.String(50), default=states.PENDING) + task_id = mapped_column(String(155), unique=True) + status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) date_done = mapped_column( - db.DateTime, + DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) traceback = mapped_column(db.Text, nullable=True) - name = mapped_column(db.String(155), nullable=True) + name = mapped_column(String(155), nullable=True) args = 
mapped_column(db.LargeBinary, nullable=True) kwargs = mapped_column(db.LargeBinary, nullable=True) - worker = mapped_column(db.String(155), nullable=True) - retries = mapped_column(db.Integer, nullable=True) - queue = mapped_column(db.String(155), nullable=True) + worker = mapped_column(String(155), nullable=True) + retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + queue = mapped_column(String(155), nullable=True) class CeleryTaskSet(Base): @@ -42,6 +43,6 @@ class CeleryTaskSet(Base): id: Mapped[int] = mapped_column( db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) - taskset_id = mapped_column(db.String(155), unique=True) + taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 68f4211e59..1491cd90ce 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, func +from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers @@ -30,8 +30,8 @@ class ToolOAuthSystemClient(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -47,8 +47,8 @@ class ToolOAuthTenantClient(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -72,26 +72,26 @@ class BuiltinToolProvider(Base): # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider encrypted_credentials: Mapped[str] = mapped_column(db.Text,
nullable=True) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) @@ -113,12 +113,12 @@ class ApiToolProvider(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon - icon = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema = mapped_column(db.Text, nullable=False) - schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -130,12 +130,12 @@ class ApiToolProvider(Base): # json format credentials credentials_str = mapped_column(db.Text, nullable=False) # privacy policy - privacy_policy = mapped_column(db.String(255), nullable=True) + privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -173,11 +173,11 @@ class ToolLabelBinding(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id - tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) + tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + tool_type: Mapped[str] = mapped_column(String(40), nullable=False) # label name - label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + label_name: Mapped[str] = mapped_column(String(40), nullable=False) class WorkflowToolProvider(Base): @@ -194,15 +194,15 @@ class WorkflowToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the 
workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # icon - icon: Mapped[str] = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -212,13 +212,13 @@ class WorkflowToolProvider(Base): # parameter configuration parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) @property @@ -253,15 +253,15 @@ class MCPToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the mcp provider - name: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider server_url: Mapped[str] = mapped_column(db.Text, nullable=False) # hash of server_url for uniqueness check - server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider - icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + icon: Mapped[str] = mapped_column(String(255), nullable=True) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who created this tool @@ -273,10 +273,10 @@ class MCPToolProvider(Base): # tools tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) def load_user(self) -> Account | None: @@ -355,11 +355,11 @@ class ToolModelInvoke(Base): # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type = mapped_column(db.String(40), nullable=False) + tool_type = mapped_column(String(40), 
nullable=False) # tool name - tool_name = mapped_column(db.String(128), nullable=False) + tool_name = mapped_column(String(128), nullable=False) # invoke parameters model_parameters = mapped_column(db.Text, nullable=False) # prompt messages @@ -367,15 +367,15 @@ class ToolModelInvoke(Base): # invoke response model_response = mapped_column(db.Text, nullable=False) - prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -402,8 +402,8 @@ class ToolConversationVariables(Base): # variables pool variables_str = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -429,11 +429,11 @@ class ToolFile(Base): # conversation id conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key - file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) + file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type - mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) + mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + original_url: Mapped[str] = mapped_column(String(2048), nullable=True) # name name: Mapped[str] = mapped_column(default="") # size @@ -465,13 +465,13 @@ class DeprecatedPublishedAppTool(Base): # to describe this parameter to llm, we need this field query_description = mapped_column(db.Text, nullable=False) # query name, the name of the query parameter - query_name = mapped_column(db.String(40), nullable=False) + query_name = mapped_column(String(40), nullable=False) # name of the tool provider - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(String(40), nullable=False) # author - author = mapped_column(db.String(40), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(String(40), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index ce00f4010f..1bf9b5c761 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,6 @@ -from sqlalchemy import func +from datetime import datetime + +from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -19,10 +21,10 @@ class SavedMessage(Base): app_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=db.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -40,7 +42,7 @@ class PinnedConversation(Base): app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=db.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index d89db6c7da..6c7d061bb4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user -from sqlalchemy import orm +from sqlalchemy import DateTime, orm from core.file.constants import maybe_file_object from core.file.models import File @@ -25,7 +25,7 @@ if TYPE_CHECKING: from models.model import AppMode import sqlalchemy as sa -from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func +from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE @@ -124,17 +124,17 @@ class Workflow(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - version: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(default="", server_default="") marked_comment: Mapped[str] = mapped_column(default="", server_default="") graph: Mapped[str] = mapped_column(sa.Text) 
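    # `graph` and `features` are JSON strings holding the workflow definition and its feature settings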
_features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=naive_utc_now(), server_onupdate=func.current_timestamp(), @@ -500,21 +500,21 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(db.String(255)) - triggered_from: Mapped[str] = mapped_column(db.String(255)) - version: Mapped[str] = mapped_column(db.String(255)) + type: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) + version: Mapped[str] = mapped_column(String(255)) graph: Mapped[Optional[str]] = mapped_column(db.Text) inputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) @property @@ -708,25 +708,25 @@ class WorkflowNodeExecutionModel(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(db.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_id: Mapped[str] = mapped_column(db.String(255)) - node_type: Mapped[str] = mapped_column(db.String(255)) - title: Mapped[str] = mapped_column(db.String(255)) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_id: Mapped[str] = mapped_column(String(255)) + node_type: Mapped[str] = mapped_column(String(255)) + title: Mapped[str] = 
mapped_column(String(255)) inputs: Mapped[Optional[str]] = mapped_column(db.Text) process_data: Mapped[Optional[str]] = mapped_column(db.Text) outputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) + status: Mapped[str] = mapped_column(String(255)) error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) @property def created_by_account(self): @@ -843,10 +843,10 @@ class WorkflowAppLog(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -873,10 +873,10 @@ class ConversationVariable(Base): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) data: Mapped[str] = mapped_column(db.Text, nullable=False) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True + DateTime, nullable=False, server_default=func.current_timestamp(), index=True ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: @@ -936,14 +936,14 @@ class WorkflowDraftVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) created_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), @@ -958,7 +958,7 @@ class WorkflowDraftVariable(Base): # # If it's not edited after creation, its value is `None`. 
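    # Like created_at/updated_at above, this is stored as naive UTC (see _naive_utc_datetime).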
last_edited_at: Mapped[datetime | None] = mapped_column( - db.DateTime, + DateTime, nullable=True, default=None, ) diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index a05e1358ed..f0d3bed057 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -1,8 +1,8 @@ import logging from datetime import datetime -from urllib.parse import urlparse import click +from kombu.utils.url import parse_url # type: ignore from redis import Redis import app @@ -10,16 +10,13 @@ from configs import dify_config from extensions.ext_database import db from libs.email_i18n import EmailType, get_email_i18n_service -# Create a dedicated Redis connection (using the same configuration as Celery) -celery_broker_url = dify_config.CELERY_BROKER_URL - -parsed = urlparse(celery_broker_url) -host = parsed.hostname or "localhost" -port = parsed.port or 6379 -password = parsed.password or None -redis_db = parsed.path.strip("/") or "1" # type: ignore - -celery_redis = Redis(host=host, port=port, password=password, db=redis_db) +redis_config = parse_url(dify_config.CELERY_BROKER_URL) +celery_redis = Redis( + host=redis_config.get("hostname") or "localhost", + port=redis_config.get("port") or 6379, + password=redis_config.get("password") or None, + db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, +) @app.celery.task(queue="monitor") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 206c832a20..692a3639cd 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,12 +1,15 @@ from collections.abc import Callable, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator +from core.variables.types import SegmentType +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db +from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable @@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, LastConversationNotExistsError, ) from services.errors.message import MessageNotExistsError @@ -220,3 +224,82 @@ class ConversationService: ] return InfiniteScrollPagination(variables, limit, has_more) + + @classmethod + def update_conversation_variable( + cls, + app_model: App, + conversation_id: str, + variable_id: str, + user: Optional[Union[Account, EndUser]], + new_value: Any, + ) -> dict: + """ + Update a conversation variable's value. 
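+        Only the value is replaced; the variable's name, type, description, and selector are preserved.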
+ + Args: + app_model: The app model + conversation_id: The conversation ID + variable_id: The variable ID to update + user: The user (Account or EndUser) + new_value: The new value for the variable + + Returns: + Dictionary containing the updated variable information + + Raises: + ConversationNotExistsError: If the conversation doesn't exist + ConversationVariableNotExistsError: If the variable doesn't exist + ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type + """ + # Verify conversation exists and user has access + conversation = cls.get_conversation(app_model, conversation_id, user) + + # Get the existing conversation variable + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .where(ConversationVariable.conversation_id == conversation.id) + .where(ConversationVariable.id == variable_id) + ) + + with Session(db.engine) as session: + existing_variable = session.scalar(stmt) + if not existing_variable: + raise ConversationVariableNotExistsError() + + # Convert existing variable to Variable object + current_variable = existing_variable.to_variable() + + # Validate that the new value type matches the expected variable type + expected_type = SegmentType(current_variable.value_type) + if not expected_type.is_valid(new_value): + inferred_type = SegmentType.infer_segment_type(new_value) + raise ConversationVariableTypeMismatchError( + f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, " + f"but got {inferred_type.value if inferred_type else 'unknown'} type" + ) + + # Create updated variable with new value only, preserving everything else + updated_variable_dict = { + "id": current_variable.id, + "name": current_variable.name, + "description": current_variable.description, + "value_type": current_variable.value_type, + "value": new_value, + "selector": current_variable.selector, + } + + updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) + + # Use the conversation variable updater to persist the changes + updater = conversation_variable_updater_factory() + updater.update(conversation_id, updated_variable) + updater.flush() + + # Return the updated variable data + return { + "created_at": existing_variable.created_at, + "updated_at": naive_utc_now(), # Update timestamp + **updated_variable.model_dump(), + } diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 1280399990..da475a18f8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2040,6 +2040,7 @@ class SegmentService: db.session.add(segment_document) # update document word count + assert document.word_count is not None document.word_count += segment_document.word_count db.session.add(document) db.session.commit() @@ -2124,6 +2125,7 @@ class SegmentService: else: keywords_list.append(None) # update document word count + assert document.word_count is not None document.word_count += increment_word_count db.session.add(document) try: @@ -2185,6 +2187,7 @@ class SegmentService: db.session.commit() # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task @@ -2260,6 +2263,7 @@ class SegmentService: word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: + assert document.word_count is 
not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) db.session.add(segment) @@ -2323,6 +2327,7 @@ class SegmentService: delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count + assert document.word_count is not None document.word_count -= segment.word_count db.session.add(document) db.session.commit() diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py index f8051e3417..a123f99b59 100644 --- a/api/services/errors/conversation.py +++ b/api/services/errors/conversation.py @@ -15,3 +15,7 @@ class ConversationCompletedError(Exception): class ConversationVariableNotExistsError(BaseServiceError): pass + + +class ConversationVariableTypeMismatchError(BaseServiceError): + pass diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 841eeb4333..da0fc58566 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -508,10 +508,10 @@ class BuiltinToolManageService: oauth_params = encrypter.decrypt(user_client.oauth_params) return oauth_params - # only verified provider can use custom oauth client - is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( - tenant_id, provider.plugin_unique_identifier - ) + # only verified provider can use official oauth client + is_verified = not isinstance( + provider_controller, PluginToolProviderController + ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if not is_verified: return oauth_params diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 714e30acc3..dee43cd854 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -134,6 +134,7 @@ def batch_create_segment_to_index_task( db.session.add(segment_document) document_segments.append(segment_document) # update document word count + assert dataset_document.word_count is not None dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py index da549af1b6..aebf3fbda1 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py @@ -2,6 +2,7 @@ import os import uuid import tablestore +from _pytest.python_api import approx from core.rag.datasource.vdb.tablestore.tablestore_vector import ( TableStoreConfig, @@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import ( class TableStoreVectorTest(AbstractVectorTest): - def __init__(self): + def __init__(self, normalize_full_text_score: bool = False): super().__init__() self.vector = TableStoreVector( collection_name=self.collection_name, @@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest): instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"), access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"), access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"), + normalize_full_text_bm25_score=normalize_full_text_score, ), ) @@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest): docs = self.vector.search_by_full_text(get_example_text(), 
document_ids_filter=[self.example_doc_id]) assert len(docs) == 1 assert docs[0].metadata["doc_id"] == self.example_doc_id - assert not hasattr(docs[0], "score") + if self.vector._config.normalize_full_text_bm25_score: + assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3) + else: + assert docs[0].metadata.get("score") is None + + # return none if normalize_full_text_score=true and score_threshold > 0 + docs = self.vector.search_by_full_text( + get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5 + ) + if self.vector._config.normalize_full_text_bm25_score: + assert len(docs) == 0 + else: + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert docs[0].metadata.get("score") is None docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) assert len(docs) == 0 @@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest): def test_tablestore_vector(setup_mock_redis): TableStoreVectorTest().run_all_tests() + TableStoreVectorTest(normalize_full_text_score=True).run_all_tests() + TableStoreVectorTest(normalize_full_text_score=False).run_all_tests() diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index e9d4ee1935..0ae6a09f5b 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -1,5 +1,6 @@ import os +import pytest from flask import Flask from packaging.version import Version from yarl import URL @@ -137,3 +138,61 @@ def test_db_extras_options_merging(monkeypatch): options = engine_options["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options + + +@pytest.mark.parametrize( + ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), + [ + ("redis://localhost:6379/1", "localhost", 6379, None, None, "1"), + ("redis://:password@localhost:6379/1", "localhost", 6379, None, "password", "1"), + ("redis://:mypass%23123@localhost:6379/1", "localhost", 6379, None, "mypass#123", "1"), + ("redis://user:pass%40word@redis-host:6380/2", "redis-host", 6380, "user", "pass@word", "2"), + ("redis://admin:complex%23pass%40word@127.0.0.1:6379/0", "127.0.0.1", 6379, "admin", "complex#pass@word", "0"), + ( + "redis://user%40domain:secret%23123@redis.example.com:6380/3", + "redis.example.com", + 6380, + "user@domain", + "secret#123", + "3", + ), + # Password containing %23 substring (double encoding scenario) + ("redis://:mypass%2523@localhost:6379/1", "localhost", 6379, None, "mypass%23", "1"), + # Username and password both containing encoded characters + ("redis://user%2525%40:pass%2523@localhost:6379/1", "localhost", 6379, "user%25@", "pass%23", "1"), + ], +) +def test_celery_broker_url_with_special_chars_password( + monkeypatch, broker_url, expected_host, expected_port, expected_username, expected_password, expected_db +): + """Test that CELERY_BROKER_URL with various formats are handled correctly.""" + from kombu.utils.url import parse_url + + # clear system environment variables + os.environ.clear() + + # Set up basic required environment variables (following existing pattern) + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", 
"5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + # Set the CELERY_BROKER_URL to test + monkeypatch.setenv("CELERY_BROKER_URL", broker_url) + + # Create config and verify the URL is stored correctly + config = DifyConfig() + assert broker_url == config.CELERY_BROKER_URL + + # Test actual parsing behavior using kombu's parse_url (same as production) + redis_config = parse_url(config.CELERY_BROKER_URL) + + # Verify the parsing results match expectations (using kombu's field names) + assert redis_config["hostname"] == expected_host + assert redis_config["port"] == expected_port + assert redis_config["userid"] == expected_username # kombu uses 'userid' not 'username' + assert redis_config["password"] == expected_password + assert redis_config["virtual_host"] == expected_db # kombu uses 'virtual_host' not 'db' diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py new file mode 100644 index 0000000000..cb5562d345 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -0,0 +1,278 @@ +import io +from unittest.mock import patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.common.errors import FilenameNotExistsError +from controllers.console.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from services.errors.file import FileTooLargeError as ServiceFileTooLargeError +from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + +class TestFileUploadSecurity: + """Test file upload security logic without complex framework setup""" + + # Test 1: Basic file validation + def test_should_validate_file_presence(self): + """Test that missing file is detected""" + from flask import Flask, request + + app = Flask(__name__) + + with app.test_request_context(method="POST", data={}): + # Simulate the check in FileApi.post() + if "file" not in request.files: + with pytest.raises(NoFileUploadedError): + raise NoFileUploadedError() + + def test_should_validate_multiple_files(self): + """Test that multiple files are rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = { + "file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"), + "file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"), + } + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + # Simulate the check in FileApi.post() + if len(request.files) > 1: + with pytest.raises(TooManyFilesError): + raise TooManyFilesError() + + def test_should_validate_empty_filename(self): + """Test that empty filename is rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")} + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + file = request.files["file"] + if not file.filename: + with pytest.raises(FilenameNotExistsError): + raise FilenameNotExistsError + + # Test 2: Security - Filename sanitization + def test_should_detect_path_traversal_in_filename(self): + """Test protection against directory traversal attacks""" + dangerous_filenames = [ + "../../../etc/passwd", + "..\\..\\windows\\system32\\config\\sam", + "../../../../etc/shadow", + "./../../../sensitive.txt", + ] + + for filename in dangerous_filenames: + # Any filename containing .. should be considered dangerous + assert ".." 
in filename, f"Filename {filename} should be detected as path traversal" + + def test_should_detect_null_byte_injection(self): + """Test protection against null byte injection""" + dangerous_filenames = [ + "file.jpg\x00.php", + "document.pdf\x00.exe", + "image.png\x00.sh", + ] + + for filename in dangerous_filenames: + # Null bytes should be detected + assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection" + + def test_should_sanitize_special_characters(self): + """Test that special characters in filenames are handled safely""" + # Characters that could be problematic in various contexts + dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"] + + for char in dangerous_chars: + filename = f"file{char}name.txt" + # These characters should be detected or sanitized + assert any(c in filename for c in dangerous_chars) + + # Test 3: Permission validation + def test_should_validate_dataset_permissions(self): + """Test dataset upload permission logic""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = "datasets" + + # Simulate the permission check in FileApi.post() + if source == "datasets" and not user.is_dataset_editor: + with pytest.raises(Forbidden): + raise Forbidden() + + def test_should_allow_general_upload_without_permission(self): + """Test general upload doesn't require dataset permission""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = None # General upload + + # This should not raise an exception + if source == "datasets" and not user.is_dataset_editor: + raise Forbidden() + # Test passes if no exception is raised + + # Test 4: Service error handling + @patch("services.file_service.FileService.upload_file") + def test_should_handle_file_too_large_error(self, mock_upload): + """Test that service FileTooLargeError is properly converted""" + mock_upload.side_effect = ServiceFileTooLargeError("File too large") + + try: + mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None) + except ServiceFileTooLargeError as e: + # Simulate the error conversion in FileApi.post() + with pytest.raises(FileTooLargeError): + raise FileTooLargeError(e.description) + + @patch("services.file_service.FileService.upload_file") + def test_should_handle_unsupported_file_type_error(self, mock_upload): + """Test that service UnsupportedFileTypeError is properly converted""" + mock_upload.side_effect = ServiceUnsupportedFileTypeError() + + try: + mock_upload( + filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None + ) + except ServiceUnsupportedFileTypeError: + # Simulate the error conversion in FileApi.post() + with pytest.raises(UnsupportedFileTypeError): + raise UnsupportedFileTypeError() + + # Test 5: File type security + def test_should_identify_dangerous_file_extensions(self): + """Test detection of potentially dangerous file extensions""" + dangerous_extensions = [ + ".php", + ".PHP", + ".pHp", # PHP files (case variations) + ".exe", + ".EXE", # Executables + ".sh", + ".SH", # Shell scripts + ".bat", + ".BAT", # Batch files + ".cmd", + ".CMD", # Command files + ".ps1", + ".PS1", # PowerShell + ".jar", + ".JAR", # Java archives + ".vbs", + ".VBS", # VBScript + ] + + safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"] + + # Just verify our test data is correct + for ext in dangerous_extensions: + assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + for 
ext in safe_extensions: + assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + def test_should_detect_double_extensions(self): + """Test detection of double extension attacks""" + suspicious_filenames = [ + "image.jpg.php", + "document.pdf.exe", + "photo.png.sh", + "file.txt.bat", + ] + + for filename in suspicious_filenames: + # Check that these have multiple extensions + parts = filename.split(".") + assert len(parts) > 2, f"Filename {filename} should have multiple extensions" + + # Test 6: Configuration validation + def test_upload_configuration_structure(self): + """Test that upload configuration has correct structure""" + # Simulate the configuration returned by FileApi.get() + config = { + "file_size_limit": 15, + "batch_count_limit": 5, + "image_file_size_limit": 10, + "video_file_size_limit": 500, + "audio_file_size_limit": 50, + "workflow_file_upload_limit": 10, + } + + # Verify all required fields are present + required_fields = [ + "file_size_limit", + "batch_count_limit", + "image_file_size_limit", + "video_file_size_limit", + "audio_file_size_limit", + "workflow_file_upload_limit", + ] + + for field in required_fields: + assert field in config, f"Missing required field: {field}" + assert isinstance(config[field], int), f"Field {field} should be an integer" + assert config[field] > 0, f"Field {field} should be positive" + + # Test 7: Source parameter handling + def test_source_parameter_normalization(self): + """Test that source parameter is properly normalized""" + test_cases = [ + ("datasets", "datasets"), + ("other", None), + ("", None), + (None, None), + ] + + for input_source, expected in test_cases: + # Simulate the source normalization in FileApi.post() + source = "datasets" if input_source == "datasets" else None + if source not in ("datasets", None): + source = None + assert source == expected + + # Test 8: Boundary conditions + def test_should_handle_edge_case_file_sizes(self): + """Test handling of boundary file sizes""" + test_cases = [ + (0, "Empty file"), # 0 bytes + (1, "Single byte"), # 1 byte + (15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB + (15 * 1024 * 1024, "At limit"), # Exactly 15MB + (15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB + ] + + for size, description in test_cases: + # Just verify our test data + assert isinstance(size, int), f"{description}: Size should be integer" + assert size >= 0, f"{description}: Size should be non-negative" + + def test_should_handle_special_mime_types(self): + """Test handling of various MIME types""" + mime_type_tests = [ + ("application/octet-stream", "Generic binary"), + ("text/plain", "Plain text"), + ("image/jpeg", "JPEG image"), + ("application/pdf", "PDF document"), + ("", "Empty MIME type"), + (None, "None MIME type"), + ] + + for mime_type, description in mime_type_tests: + # Verify test data structure + if mime_type is not None: + assert isinstance(mime_type, str), f"{description}: MIME type should be string or None" diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 209f8b7c57..1dc380ad0b 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -102,9 +102,14 @@ class TestPhoenixConfig: assert config.project == "default" def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1") - 
assert config.endpoint == "https://custom.phoenix.com" + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" class TestLangfuseConfig: @@ -368,13 +373,15 @@ class TestConfigIntegration: """Test that URL normalization works consistently across configs""" # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") - phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/") + phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") aliyun_config = AliyunConfig( license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" ) assert arize_config.endpoint == "https://arize.com" - assert phoenix_config.endpoint == "https://phoenix.com" + assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com" assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" def test_project_default_values(self): diff --git a/docker/.env.example b/docker/.env.example index 7ecdf899fe..13cac189aa 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -653,6 +653,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx +TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false # ------------------------------ # Knowledge Configuration diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ae83aa758d..690dccb1a8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -312,6 +312,7 @@ x-shared-env: &shared-api-worker-env TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name} TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx} TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx} + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index 173aa96118..3bde095f4b 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -49,9 +49,9 @@ describe('check-i18n script functionality', () => { } vm.runInNewContext(transpile(content), context) - const translationObj = moduleExports.default || moduleExports + const translationObj = (context.module.exports as any).default || context.module.exports - if(!translationObj || typeof translationObj !== 'object') + if (!translationObj || typeof translationObj !== 'object') throw new Error(`Error parsing file: ${filePath}`) const nestedKeys: string[] = [] @@ -62,7 +62,7 @@ describe('check-i18n script functionality', () => { // This is an object (but not array), recurse into it but don't add it as a key iterateKeys(obj[key], nestedKey) } - else { + else { // This is a leaf node (string, 
number, boolean, array, etc.), add it as a key nestedKeys.push(nestedKey) } @@ -73,7 +73,7 @@ describe('check-i18n script functionality', () => { const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) allKeys.push(...fileKeys) } - catch (error) { + catch (error) { reject(error) } }) @@ -272,9 +272,6 @@ export default translation const filteredEnKeys = allEnKeys.filter(key => key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), ) - const filteredZhKeys = allZhKeys.filter(key => - key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), - ) expect(allEnKeys).toHaveLength(4) // 2 keys from each file expect(filteredEnKeys).toHaveLength(2) // only components keys diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx new file mode 100644 index 0000000000..370052bc80 --- /dev/null +++ b/web/__tests__/plugin-tool-workflow-error.test.tsx @@ -0,0 +1,207 @@ +/** + * Test cases to reproduce the plugin tool workflow error + * Issue: #23154 - Application error when loading plugin tools in workflow + * Root cause: split() operation called on null/undefined values + */ + +describe('Plugin Tool Workflow Error Reproduction', () => { + /** + * Mock function to simulate the problematic code in switch-plugin-version.tsx:29 + * const [pluginId] = uniqueIdentifier.split(':') + */ + const mockSwitchPluginVersionLogic = (uniqueIdentifier: string | null | undefined) => { + // This directly reproduces the problematic line from switch-plugin-version.tsx:29 + const [pluginId] = uniqueIdentifier!.split(':') + return pluginId + } + + /** + * Test case 1: Simulate null uniqueIdentifier + * This should reproduce the error mentioned in the issue + */ + it('should reproduce error when uniqueIdentifier is null', () => { + expect(() => { + mockSwitchPluginVersionLogic(null) + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test case 2: Simulate undefined uniqueIdentifier + */ + it('should reproduce error when uniqueIdentifier is undefined', () => { + expect(() => { + mockSwitchPluginVersionLogic(undefined) + }).toThrow('Cannot read properties of undefined (reading \'split\')') + }) + + /** + * Test case 3: Simulate empty string uniqueIdentifier + */ + it('should handle empty string uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('') + expect(result).toBe('') // Empty string split by ':' returns [''] + }).not.toThrow() + }) + + /** + * Test case 4: Simulate malformed uniqueIdentifier without colon separator + */ + it('should handle malformed uniqueIdentifier without colon separator', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('malformed-identifier-without-colon') + expect(result).toBe('malformed-identifier-without-colon') // No colon means full string returned + }).not.toThrow() + }) + + /** + * Test case 5: Simulate valid uniqueIdentifier + */ + it('should work correctly with valid uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('valid-plugin-id:1.0.0') + expect(result).toBe('valid-plugin-id') + }).not.toThrow() + }) +}) + +/** + * Test for the variable processing split error in use-single-run-form-params + */ +describe('Variable Processing Split Error', () => { + /** + * Mock function to simulate the problematic code in use-single-run-form-params.ts:91 + * const getDependentVars = () => { + * return varInputs.map(item => item.variable.slice(1, -1).split('.')) + * } + 
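+   * The guarded mock below returns [] for a null/undefined variable instead of throwing.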
*/ + const mockGetDependentVars = (varInputs: Array<{ variable: string | null | undefined }>) => { + return varInputs.map((item) => { + // Guard against null/undefined variable to prevent app crash + if (!item.variable || typeof item.variable !== 'string') + return [] + + return item.variable.slice(1, -1).split('.') + }).filter(arr => arr.length > 0) // Filter out empty arrays + } + + /** + * Test case 1: Variable processing with null variable + */ + it('should handle null variable safely', () => { + const varInputs = [{ variable: null }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // null variables are filtered out + }) + + /** + * Test case 2: Variable processing with undefined variable + */ + it('should handle undefined variable safely', () => { + const varInputs = [{ variable: undefined }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // undefined variables are filtered out + }) + + /** + * Test case 3: Variable processing with empty string + */ + it('should handle empty string variable', () => { + const varInputs = [{ variable: '' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // Empty string is filtered out, so result is empty array + }) + + /** + * Test case 4: Variable processing with valid variable format + */ + it('should work correctly with valid variable format', () => { + const varInputs = [{ variable: '{{workflow.node.output}}' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result[0]).toEqual(['{workflow', 'node', 'output}']) + }) +}) + +/** + * Integration test to simulate the complete workflow scenario + */ +describe('Plugin Tool Workflow Integration', () => { + /** + * Simulate the scenario where plugin metadata is incomplete or corrupted + * This can happen when: + * 1. Plugin is being loaded from marketplace but metadata request fails + * 2. Plugin configuration is corrupted in database + * 3. 
Network issues during plugin loading + */ + it('should reproduce the client-side exception scenario', () => { + // Mock incomplete plugin data that could cause the error + const incompletePluginData = { + // Missing or null uniqueIdentifier + uniqueIdentifier: null, + meta: null, + minimum_dify_version: undefined, + } + + // This simulates the error path that leads to the white screen + expect(() => { + // Simulate the code path in switch-plugin-version.tsx:29 + // The actual problematic code doesn't use optional chaining + const _pluginId = (incompletePluginData.uniqueIdentifier as any).split(':')[0] + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test the scenario mentioned in the issue where plugin tools are loaded in workflow + */ + it('should simulate plugin tool loading in workflow context', () => { + // Mock the workflow context where plugin tools are being loaded + const workflowPluginTools = [ + { + provider_name: 'test-plugin', + uniqueIdentifier: null, // This is the problematic case + tool_name: 'test-tool', + }, + { + provider_name: 'valid-plugin', + uniqueIdentifier: 'valid-plugin:1.0.0', + tool_name: 'valid-tool', + }, + ] + + // Process each plugin tool + workflowPluginTools.forEach((tool, _index) => { + if (tool.uniqueIdentifier === null) { + // This reproduces the exact error scenario + expect(() => { + const _pluginId = (tool.uniqueIdentifier as any).split(':')[0] + }).toThrow() + } + else { + // Valid tools should work fine + expect(() => { + const _pluginId = tool.uniqueIdentifier.split(':')[0] + }).not.toThrow() + } + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 58c9f7e5ca..c04d79d2f2 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -271,16 +271,17 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx - { - expand && ( -
-          <div ...>
-            <div ...>
-              <div ...>{appDetail.name}</div>
-            </div>
-            <div ...>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
-          </div>
-        )
-      }
+      <div ...>
+        <div ...>
+          <div ...>{appDetail.name}</div>
+        </div>
+        <div ...>{appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}</div>
+      </div>
)} diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index b6bfc0e9ac..cf32339b8a 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -124,10 +124,7 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati { !isMobile && (
({ + useSelectedLayoutSegment: () => 'overview', +})) + +// Mock Next.js Link component +jest.mock('next/link', () => { + return function MockLink({ children, href, className, title }: any) { + return ( + + {children} + + ) + } +}) + +// Mock RemixIcon components +const MockIcon = ({ className }: { className?: string }) => ( + +) + +describe('NavLink Text Animation Issues', () => { + const mockProps: NavLinkProps = { + name: 'Orchestrate', + href: '/app/123/workflow', + iconMap: { + selected: MockIcon, + normal: MockIcon, + }, + } + + beforeEach(() => { + // Mock getComputedStyle for transition testing + Object.defineProperty(window, 'getComputedStyle', { + value: jest.fn((element) => { + const isExpanded = element.getAttribute('data-mode') === 'expand' + return { + transition: 'all 0.3s ease', + opacity: isExpanded ? '1' : '0', + width: isExpanded ? 'auto' : '0px', + overflow: 'hidden', + paddingLeft: isExpanded ? '12px' : '10px', // px-3 vs px-2.5 + paddingRight: isExpanded ? '12px' : '10px', + } + }), + writable: true, + }) + }) + + describe('Text Squeeze Animation Issue', () => { + it('should show text squeeze effect when switching from collapse to expand', async () => { + const { rerender } = render() + + // In collapse mode, text should be in DOM but hidden via CSS + const textElement = screen.getByText('Orchestrate') + expect(textElement).toBeInTheDocument() + expect(textElement).toHaveClass('opacity-0') + expect(textElement).toHaveClass('w-0') + expect(textElement).toHaveClass('overflow-hidden') + + // Icon should still be present + expect(screen.getByTestId('nav-icon')).toBeInTheDocument() + + // Check padding in collapse mode + const linkElement = screen.getByTestId('nav-link') + expect(linkElement).toHaveClass('px-2.5') + + // Switch to expand mode - this is where the squeeze effect occurs + rerender() + + // Text should now appear + expect(screen.getByText('Orchestrate')).toBeInTheDocument() + + // Check padding change - this contributes to the squeeze effect + expect(linkElement).toHaveClass('px-3') + + // The bug: text appears abruptly without smooth transition + // This test documents the current behavior that causes the squeeze effect + const expandedTextElement = screen.getByText('Orchestrate') + expect(expandedTextElement).toBeInTheDocument() + + // In a properly animated version, we would expect: + // - Opacity transition from 0 to 1 + // - Width transition from 0 to auto + // - No layout shift from padding changes + }) + + it('should maintain icon position consistency during text appearance', () => { + const { rerender } = render() + + const iconElement = screen.getByTestId('nav-icon') + const initialIconClasses = iconElement.className + + // Icon should have mr-0 in collapse mode + expect(iconElement).toHaveClass('mr-0') + + rerender() + + const expandedIconClasses = iconElement.className + + // Icon should have mr-2 in expand mode - this shift contributes to the squeeze effect + expect(iconElement).toHaveClass('mr-2') + + console.log('Collapsed icon classes:', initialIconClasses) + console.log('Expanded icon classes:', expandedIconClasses) + + // This margin change causes the icon to shift when text appears + }) + + it('should document the abrupt text rendering issue', () => { + const { rerender } = render() + + // Text is present in DOM but hidden via CSS classes + const collapsedText = screen.getByText('Orchestrate') + expect(collapsedText).toBeInTheDocument() + expect(collapsedText).toHaveClass('opacity-0') + expect(collapsedText).toHaveClass('pointer-events-none') 
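+      // `pointer-events-none` keeps the visually hidden label from intercepting clicks while collapsed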
+ + rerender() + + // Text suddenly appears in DOM - no transition + expect(screen.getByText('Orchestrate')).toBeInTheDocument() + + // The issue: {mode === 'expand' && name} causes abrupt show/hide + // instead of smooth opacity/width transition + }) + }) + + describe('Layout Shift Issues', () => { + it('should detect padding differences causing layout shifts', () => { + const { rerender } = render() + + const linkElement = screen.getByTestId('nav-link') + + // Collapsed state padding + expect(linkElement).toHaveClass('px-2.5') + + rerender() + + // Expanded state padding - different value causes layout shift + expect(linkElement).toHaveClass('px-3') + + // This 2px difference (10px vs 12px) contributes to the squeeze effect + }) + + it('should detect icon margin changes causing shifts', () => { + const { rerender } = render() + + const iconElement = screen.getByTestId('nav-icon') + + // Collapsed: no right margin + expect(iconElement).toHaveClass('mr-0') + + rerender() + + // Expanded: 8px right margin (mr-2) + expect(iconElement).toHaveClass('mr-2') + + // This sudden margin appearance causes the squeeze effect + }) + }) + + describe('Active State Handling', () => { + it('should handle active state correctly in both modes', () => { + // Test non-active state + const { rerender } = render() + + let linkElement = screen.getByTestId('nav-link') + expect(linkElement).not.toHaveClass('bg-state-accent-active') + + // Test with active state (when href matches current segment) + const activeProps = { + ...mockProps, + href: '/app/123/overview', // matches mocked segment + } + + rerender() + + linkElement = screen.getByTestId('nav-link') + expect(linkElement).toHaveClass('bg-state-accent-active') + }) + }) +}) diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index 295b553b04..4607f7b693 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -44,20 +44,29 @@ export default function NavLink({ key={name} href={href} className={classNames( - isActive ? 'bg-state-accent-active text-text-accent font-semibold' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover', - 'group flex items-center h-9 rounded-md py-2 text-sm font-normal', + isActive ? 'bg-state-accent-active font-semibold text-text-accent' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover', + 'group flex h-9 items-center rounded-md py-2 text-sm font-normal', mode === 'expand' ? 'px-3' : 'px-2.5', )} title={mode === 'collapse' ? name : ''} >
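For reference, a minimal sketch of the transition-friendly label rendering that the tests above assert: the label stays mounted and its visibility is animated, instead of the abrupt `{mode === 'expand' && name}` mount/unmount. The utility classes and the 300ms duration are assumptions taken from the test expectations and mocks, not from the actual navLink.tsx change, and the `classnames` import stands in for whatever helper the component already uses.

import classNames from 'classnames'

type NavLabelProps = {
  name: string
  mode: 'expand' | 'collapse'
}

// Keep the text in the DOM in both modes; animate opacity/width so expanding
// does not re-layout the row (the "squeeze" the tests document).
const NavLabel = ({ name, mode }: NavLabelProps) => (
  <span
    className={classNames(
      'transition-all duration-300',
      mode === 'expand'
        ? 'opacity-100'
        : 'pointer-events-none w-0 overflow-hidden opacity-0',
    )}
  >
    {name}
  </span>
)

export default NavLabel

With this shape, the collapse-mode assertions (`opacity-0`, `w-0`, `overflow-hidden`, `pointer-events-none`) hold while the element itself never leaves the DOM.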