diff --git a/README.md b/README.md
index 2909e0e6cf..775f6f351f 100644
--- a/README.md
+++ b/README.md
@@ -235,13 +235,17 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/
 
 One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Deploy to AKS with Azure DevOps Pipeline
+
+One-Click deploy Dify to AKS with the [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contributing
 
 For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
 
-> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
+> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
 
 ## Community & contact
 
diff --git a/README_AR.md b/README_AR.md
index e959ca0f78..e7a4dbdb27 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -217,13 +217,17 @@ docker compose up -d
 
 انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### استخدام Azure DevOps Pipeline للنشر على AKS
+
+انشر Dify على AKS بنقرة واحدة باستخدام [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## المساهمة
 
 لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
 في الوقت نفسه، يرجى النظر في دعم Dify عن طريق مشاركته على وسائل التواصل الاجتماعي وفي الفعاليات والمؤتمرات.
 
-> نحن نبحث عن مساهمين لمساعدة في ترجمة Dify إلى لغات أخرى غير اللغة الصينية المندرين أو الإنجليزية. إذا كنت مهتمًا بالمساعدة، يرجى الاطلاع على [README للترجمة](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) لمزيد من المعلومات، واترك لنا تعليقًا في قناة `global-users` على [خادم المجتمع على Discord](https://discord.gg/8Tpq4AcN9c).
+> نحن نبحث عن مساهمين لمساعدة في ترجمة Dify إلى لغات أخرى غير اللغة الصينية المندرين أو الإنجليزية. إذا كنت مهتمًا بالمساعدة، يرجى الاطلاع على [README للترجمة](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) لمزيد من المعلومات، واترك لنا تعليقًا في قناة `global-users` على [خادم المجتمع على Discord](https://discord.gg/8Tpq4AcN9c).
 
 **المساهمون**
 
diff --git a/README_BN.md b/README_BN.md
index 29d7374ea5..e4da437eff 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -235,13 +235,17 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### AKS-এ ডিপ্লয় করার জন্য Azure DevOps Pipeline ব্যবহার
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন
+
 ## Contributing
 
 যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। একই সাথে, সোশ্যাল মিডিয়া এবং ইভেন্ট এবং কনফারেন্সে এটি শেয়ার করে Dify কে সমর্থন করুন।
 
-> আমরা ম্যান্ডারিন বা ইংরেজি ছাড়া অন্য ভাষায় Dify অনুবাদ করতে সাহায্য করার জন্য অবদানকারীদের খুঁজছি। আপনি যদি সাহায্য করতে আগ্রহী হন, তাহলে আরও তথ্যের জন্য [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) দেখুন এবং আমাদের [ডিসকর্ড কমিউনিটি সার্ভার](https://discord.gg/8Tpq4AcN9c) এর `গ্লোবাল-ইউজারস` চ্যানেলে আমাদের একটি মন্তব্য করুন।
+> আমরা ম্যান্ডারিন বা ইংরেজি ছাড়া অন্য ভাষায় Dify অনুবাদ করতে সাহায্য করার জন্য অবদানকারীদের খুঁজছি। আপনি যদি সাহায্য করতে আগ্রহী হন, তাহলে আরও তথ্যের জন্য [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) দেখুন এবং আমাদের [ডিসকর্ড কমিউনিটি সার্ভার](https://discord.gg/8Tpq4AcN9c) এর `গ্লোবাল-ইউজারস` চ্যানেলে আমাদের একটি মন্তব্য করুন।
 
 ## কমিউনিটি এবং যোগাযোগ
 
diff --git a/README_CN.md b/README_CN.md
index 486a368c09..82149519d3 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -233,6 +233,9 @@ docker compose up -d
 
 使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
 
+#### 使用 Azure DevOps Pipeline 部署到 AKS
+
+使用 [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS
 
 ## Star History
 
@@ -244,7 +247,7 @@ docker compose up -d
 对于那些想要贡献代码的人，请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
 同时，请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
 
-> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助，请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息，并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
+> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助，请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息，并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
 
 **Contributors**
 
diff --git a/README_DE.md b/README_DE.md
index fce52c34c2..2420ac0392 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -230,13 +230,17 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
 
 Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Verwendung von Azure DevOps Pipeline für AKS-Bereitstellung
+
+Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden
+
 ## Contributing
 
 Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
 
-> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).
+> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).
 
 ## Gemeinschaft & Kontakt
 
diff --git a/README_ES.md b/README_ES.md
index 6fd6dfcee8..4fa59dc18f 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -230,6 +230,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
 
 Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Uso de Azure DevOps Pipeline para implementar en AKS
+
+Implementa Dify en AKS con un clic usando [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuir
 
@@ -237,7 +241,7 @@ Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias.
 
-> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).
+> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).
 
 **Contribuidores**
 
diff --git a/README_FR.md b/README_FR.md
index b2209fb495..dcbc869620 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -228,6 +228,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
 
 Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Utilisation d'Azure DevOps Pipeline pour déployer sur AKS
+
+Déployez Dify sur AKS en un clic en utilisant [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuer
 
@@ -235,7 +239,7 @@ Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences.
 
-> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).
+> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).
 
 **Contributeurs**
 
diff --git a/README_JA.md b/README_JA.md
index c658225f90..d840fd6419 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -227,6 +227,10 @@ docker compose up -d
 #### Alibaba Cloud Data Management
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
 
+#### AKSへのデプロイにAzure DevOps Pipelineを使用
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ
+
 ## 貢献
 
@@ -234,7 +238,7 @@ docker compose up -d
 同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。
 
-> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。
+> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。
 
 **貢献者**
 
diff --git a/README_KL.md b/README_KL.md
index bfafcc7407..41c7969e1c 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -228,6 +228,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### AKS 'e' Deploy je Azure DevOps Pipeline lo'laH
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy
+
 ## Contributing
 
@@ -235,7 +239,7 @@ For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
 
-> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
+> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
 
 **Contributors**
 
diff --git a/README_KR.md b/README_KR.md
index 282117e776..d4b31a8928 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -222,6 +222,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
 
+#### AKS에 배포하기 위해 Azure DevOps Pipeline 사용
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포
+
 ## 기여
 
@@ -229,7 +233,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
 동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
 
-> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
+> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
 
 **기여자**
 
diff --git a/README_PT.md b/README_PT.md
index 576f6b48f7..94452cb233 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -227,13 +227,17 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
 
 Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Usando Azure DevOps Pipeline para Implantar no AKS
+
+Implante o Dify no AKS com um clique usando [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Contribuindo
 
 Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
 
-> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
+> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
 
 **Contribuidores**
 
diff --git a/README_SI.md b/README_SI.md
index 7ded001d86..d840e9155f 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -228,6 +228,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
 
 Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Uporaba Azure DevOps Pipeline za uvajanje v AKS
+
+Z enim klikom namestite Dify v AKS z uporabo [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Prispevam
 
diff --git a/README_TR.md b/README_TR.md
index 6e94e54fa0..470a7570e0 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -221,13 +221,17 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
 
 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
 
+#### AKS'ye Dağıtım için Azure DevOps Pipeline Kullanımı
+
+[Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın
+
 ## Katkıda Bulunma
 
 Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
 Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün.
 
-> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.
+> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.
 
 **Katkıda Bulunanlar**
 
diff --git a/README_TW.md b/README_TW.md
index 6e3e22b5c1..18f1d2754a 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -233,13 +233,17 @@ Dify 的所有功能都提供相應的 API，因此您可以輕鬆地將 Dify
 
 透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)，一鍵將 Dify 部署至阿里雲
 
+#### 使用 Azure DevOps Pipeline 部署到 AKS
+
+使用 [Azure DevOps Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS
+
 ## 貢獻
 
 對於想要貢獻程式碼的開發者，請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
 同時，也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
 
-> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙，請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) 獲取更多資訊，並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
+> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙，請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊，並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
 
 ## 社群與聯絡方式
 
diff --git a/README_VI.md b/README_VI.md
index 51314e6de5..2ab6da80fc 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -224,6 +224,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
 
 Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
 
+#### Sử dụng Azure DevOps Pipeline để Triển khai lên AKS
+
+Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure DevOps Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
 ## Đóng góp
 
@@ -231,7 +235,7 @@ Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng
 Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị.
 
-> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.
+> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.
 
 **Người đóng góp**
 
diff --git a/api/.env.example b/api/.env.example
index 18f2dbf647..4beabfecea 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
 TABLESTORE_INSTANCE_NAME=instance-name
 TABLESTORE_ACCESS_KEY_ID=xxx
 TABLESTORE_ACCESS_KEY_SECRET=xxx
+TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
 
 # Tidb Vector configuration
 TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
 
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 68b16e48db..ff290ff99d 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -215,7 +215,7 @@ class DatabaseConfig(BaseSettings):
 
 class CeleryConfig(DatabaseConfig):
     CELERY_BACKEND: str = Field(
-        description="Backend for Celery task results. Options: 'database', 'redis'.",
+        description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
         default="redis",
     )
 
@@ -245,7 +245,12 @@ class CeleryConfig(DatabaseConfig):
     @computed_field
     def CELERY_RESULT_BACKEND(self) -> str | None:
-        return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
+        if self.CELERY_BACKEND in ("database", "rabbitmq"):
+            return f"db+{self.SQLALCHEMY_DATABASE_URI}"
+        elif self.CELERY_BACKEND == "redis":
+            return self.CELERY_BROKER_URL
+        else:
+            return None
 
     @property
     def BROKER_USE_SSL(self) -> bool:
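The `CELERY_RESULT_BACKEND` hunk above replaces a one-line conditional with an explicit three-way mapping. A minimal, self-contained sketch of that selection logic (a standalone function for illustration, not the actual Dify config class):

```python
# Sketch of the CELERY_RESULT_BACKEND selection logic added above: a
# RabbitMQ broker gets a database-backed result store (SQLAlchemy's "db+"
# URI prefix), Redis doubles as its own result backend, and any other
# value now yields None instead of silently reusing the broker URL.
def celery_result_backend(backend: str, database_uri: str, broker_url: str) -> str | None:
    if backend in ("database", "rabbitmq"):
        return f"db+{database_uri}"
    elif backend == "redis":
        return broker_url
    return None


assert celery_result_backend("rabbitmq", "postgresql://dify", "amqp://mq") == "db+postgresql://dify"
assert celery_result_backend("redis", "postgresql://dify", "redis://cache:6379/1") == "redis://cache:6379/1"
assert celery_result_backend("unknown", "postgresql://dify", "redis://cache:6379/1") is None
```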
diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py
index df8182985d..8c4b333d45 100644
--- a/api/configs/middleware/vdb/elasticsearch_config.py
+++ b/api/configs/middleware/vdb/elasticsearch_config.py
@@ -1,12 +1,13 @@
 from typing import Optional
 
-from pydantic import Field, PositiveInt
+from pydantic import Field, PositiveInt, model_validator
 from pydantic_settings import BaseSettings
 
 
 class ElasticsearchConfig(BaseSettings):
     """
-    Configuration settings for Elasticsearch
+    Configuration settings for both self-managed and Elastic Cloud deployments.
+    Can load from environment variables or .env files.
     """
 
     ELASTICSEARCH_HOST: Optional[str] = Field(
@@ -28,3 +29,50 @@ class ElasticsearchConfig(BaseSettings):
         description="Password for authenticating with Elasticsearch (default is 'elastic')",
         default="elastic",
     )
+
+    # Elastic Cloud (optional)
+    ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field(
+        description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False
+    )
+    ELASTICSEARCH_CLOUD_URL: Optional[str] = Field(
+        description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')",
+        default=None,
+    )
+    ELASTICSEARCH_API_KEY: Optional[str] = Field(
+        description="API key for authenticating with Elastic Cloud", default=None
+    )
+
+    # Common options
+    ELASTICSEARCH_CA_CERTS: Optional[str] = Field(
+        description="Path to CA certificate file for SSL verification", default=None
+    )
+    ELASTICSEARCH_VERIFY_CERTS: bool = Field(
+        description="Whether to verify SSL certificates (default is False)", default=False
+    )
+    ELASTICSEARCH_REQUEST_TIMEOUT: int = Field(
+        description="Request timeout in milliseconds (default is 100000)", default=100000
+    )
+    ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = Field(
+        description="Whether to retry requests on timeout (default is True)", default=True
+    )
+    ELASTICSEARCH_MAX_RETRIES: int = Field(
+        description="Maximum number of retry attempts (default is 10000)", default=10000
+    )
+
+    @model_validator(mode="after")
+    def validate_elasticsearch_config(self):
+        """Validate Elasticsearch configuration based on deployment type."""
+        if self.ELASTICSEARCH_USE_CLOUD:
+            if not self.ELASTICSEARCH_CLOUD_URL:
+                raise ValueError("ELASTICSEARCH_CLOUD_URL is required when using Elastic Cloud")
+            if not self.ELASTICSEARCH_API_KEY:
+                raise ValueError("ELASTICSEARCH_API_KEY is required when using Elastic Cloud")
+        else:
+            if not self.ELASTICSEARCH_HOST:
+                raise ValueError("ELASTICSEARCH_HOST is required for self-hosted Elasticsearch")
+            if not self.ELASTICSEARCH_USERNAME:
+                raise ValueError("ELASTICSEARCH_USERNAME is required for self-hosted Elasticsearch")
+            if not self.ELASTICSEARCH_PASSWORD:
+                raise ValueError("ELASTICSEARCH_PASSWORD is required for self-hosted Elasticsearch")
+
+        return self
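For reference, the two deployment modes accepted by the new settings might be driven from the environment like this (variable names are the ones defined above; all values are placeholders):

```python
import os

# Elastic Cloud mode: the validator above requires both the cloud URL and
# an API key once ELASTICSEARCH_USE_CLOUD is enabled.
os.environ["ELASTICSEARCH_USE_CLOUD"] = "true"
os.environ["ELASTICSEARCH_CLOUD_URL"] = "https://example.es.region.aws.found.io:443"
os.environ["ELASTICSEARCH_API_KEY"] = "your-api-key"

# Self-hosted mode (the default): host, username, and password must be set
# instead; the cloud-only variables can stay unset.
# os.environ["ELASTICSEARCH_HOST"] = "http://127.0.0.1"
# os.environ["ELASTICSEARCH_PORT"] = "9200"
# os.environ["ELASTICSEARCH_USERNAME"] = "elastic"
# os.environ["ELASTICSEARCH_PASSWORD"] = "elastic"
```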
diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py
index c4dcc0d465..1aab01c6e1 100644
--- a/api/configs/middleware/vdb/tablestore_config.py
+++ b/api/configs/middleware/vdb/tablestore_config.py
@@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
         description="AccessKey secret for the instance name",
         default=None,
     )
+
+    TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
+        description="Whether to normalize full-text search scores to [0, 1]",
+        default=False,
+    )
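The new flag only declares intent here; the normalization itself happens in the Tablestore vector implementation, which this diff does not show. Purely as an illustration of the idea, one common way to squash unbounded BM25 scores into [0, 1]:

```python
# Illustrative only: a bounded transform that maps raw BM25 scores
# (unbounded, >= 0) into [0, 1). The transform Dify actually applies when
# TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE is enabled is not part of this diff.
def normalize_bm25_score(score: float, k: float = 1.0) -> float:
    return score / (score + k)


assert 0.0 <= normalize_bm25_score(7.3) < 1.0
```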
"/apps//annotation-reply api.add_resource( AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" ) -api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationApi, "/apps//annotations") api.add_resource(AnnotationExportApi, "/apps//annotations/export") api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index b1a83aa371..65f76fb402 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -22,8 +22,8 @@ class DatasetMetadataCreateApi(Resource): @marshal_with(dataset_metadata_fields) def post(self, dataset_id): parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=True, location="json") - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() metadata_args = MetadataArgs(**args) @@ -56,7 +56,7 @@ class DatasetMetadataApi(Resource): @marshal_with(dataset_metadata_fields) def patch(self, dataset_id, metadata_id): parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id_str = str(dataset_id) @@ -127,7 +127,7 @@ class DocumentMetadataEditApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) parser = reqparse.RequestParser() - parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") + parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") args = parser.parse_args() metadata_args = MetadataOperationData(**args) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 7762672494..edc66cc5e9 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -47,6 +47,9 @@ class CompletionApi(Resource): parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id streaming = args["response_mode"] == "streaming" diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 36a7905572..79c860e6b8 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,7 +1,9 @@ +import json + from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range from sqlalchemy.orm import Session -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import BadRequest, NotFound import services from controllers.service_api import api @@ -15,6 +17,7 @@ from fields.conversation_fields import ( simple_conversation_fields, ) from fields.conversation_variable_fields import ( + conversation_variable_fields, conversation_variable_infinite_scroll_pagination_fields, ) from libs.helper import uuid_value @@ -120,7 +123,41 @@ class 
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index b1a83aa371..65f76fb402 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -22,8 +22,8 @@ class DatasetMetadataCreateApi(Resource):
     @marshal_with(dataset_metadata_fields)
     def post(self, dataset_id):
         parser = reqparse.RequestParser()
-        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("type", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         metadata_args = MetadataArgs(**args)
@@ -56,7 +56,7 @@ class DatasetMetadataApi(Resource):
     @marshal_with(dataset_metadata_fields)
     def patch(self, dataset_id, metadata_id):
         parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         dataset_id_str = str(dataset_id)
@@ -127,7 +127,7 @@ class DocumentMetadataEditApi(Resource):
         DatasetService.check_dataset_permission(dataset, current_user)
 
         parser = reqparse.RequestParser()
-        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
+        parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         metadata_args = MetadataOperationData(**args)
 
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 7762672494..edc66cc5e9 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -47,6 +47,9 @@ class CompletionApi(Resource):
         parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
         args = parser.parse_args()
 
+        external_trace_id = get_external_trace_id(request)
+        if external_trace_id:
+            args["external_trace_id"] = external_trace_id
 
         streaming = args["response_mode"] == "streaming"
 
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 36a7905572..79c860e6b8 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,7 +1,9 @@
+import json
+
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import BadRequest, NotFound
 
 import services
 from controllers.service_api import api
@@ -15,6 +17,7 @@ from fields.conversation_fields import (
     simple_conversation_fields,
 )
 from fields.conversation_variable_fields import (
+    conversation_variable_fields,
     conversation_variable_infinite_scroll_pagination_fields,
 )
 from libs.helper import uuid_value
@@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
             raise NotFound("Conversation Not Exists.")
 
 
+class ConversationVariableDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    @marshal_with(conversation_variable_fields)
+    def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
+        """Update a conversation variable's value"""
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+            raise NotChatAppError()
+
+        conversation_id = str(c_id)
+        variable_id = str(variable_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("value", required=True, location="json")
+        args = parser.parse_args()
+
+        try:
+            return ConversationService.update_conversation_variable(
+                app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
+            )
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationVariableNotExistsError:
+            raise NotFound("Conversation Variable Not Exists.")
+        except services.errors.conversation.ConversationVariableTypeMismatchError as e:
+            raise BadRequest(str(e))
+
+
 api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
 api.add_resource(ConversationApi, "/conversations")
 api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
 api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
+api.add_resource(
+    ConversationVariableDetailApi,
+    "/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
+    endpoint="conversation_variable_detail",
+    methods=["PUT"],
+)
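A hypothetical Service API call against the `PUT` endpoint registered above. Note that the handler runs `json.loads` on the `value` field, so the new value travels as a JSON-encoded string; the `user` field follows the usual Service API convention of identifying the end user (all values are placeholders):

```python
import json

import requests

BASE = "https://api.dify.ai/v1"  # placeholder base URL
headers = {"Authorization": "Bearer app-your-api-key"}  # placeholder key

conversation_id = "your-conversation-id"
variable_id = "your-variable-id"

resp = requests.put(
    f"{BASE}/conversations/{conversation_id}/variables/{variable_id}",
    headers=headers,
    # "value" is JSON-encoded because the handler json.loads() it before
    # passing it to ConversationService.update_conversation_variable.
    json={"value": json.dumps({"city": "Berlin"}), "user": "end-user-123"},
)
resp.raise_for_status()
print(resp.json())  # response marshalled with conversation_variable_fields
```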
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 1968696ee5..6ba818c5fc 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -17,8 +17,8 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id, dataset_id):
         parser = reqparse.RequestParser()
-        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("type", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         metadata_args = MetadataArgs(**args)
@@ -43,7 +43,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def patch(self, tenant_id, dataset_id, metadata_id):
         parser = reqparse.RequestParser()
-        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         dataset_id_str = str(dataset_id)
@@ -101,7 +101,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
         DatasetService.check_dataset_permission(dataset, current_user)
 
         parser = reqparse.RequestParser()
-        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
+        parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         metadata_args = MetadataOperationData(**args)
 
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 75bd2f677a..0df0aa59b2 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[
     "is not",
     "empty",
     "not empty",
+    "in",
+    "not in",
     # for number
     "=",
     "≠",
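With `"in"` and `"not in"` added to `SupportedComparisonOperator`, a metadata filter condition can now match against a list of values. The exact condition schema lives elsewhere in Dify, so the field names below are indicative only:

```python
# Indicative shape of a metadata filter condition using the new "in"
# operator; field names are illustrative, not the authoritative schema.
condition = {
    "name": "category",
    "comparison_operator": "in",
    "value": ["faq", "release-notes"],
}
```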
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index af0e38f7ef..06050619e9 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session, sessionmaker
 from core.ops.aliyun_trace.data_exporter.traceclient import (
     TraceClient,
     convert_datetime_to_nanoseconds,
+    convert_string_to_id,
     convert_to_span_id,
     convert_to_trace_id,
     generate_span_id,
@@ -101,8 +102,9 @@ class AliyunDataTrace(BaseTraceInstance):
             raise ValueError(f"Aliyun get run url failed: {str(e)}")
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        external_trace_id = trace_info.metadata.get("external_trace_id")
-        trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id)
+        trace_id = convert_to_trace_id(trace_info.workflow_run_id)
+        if trace_info.trace_id:
+            trace_id = convert_string_to_id(trace_info.trace_id)
         workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
         self.add_workflow_span(trace_id, workflow_span_id, trace_info)
 
@@ -130,6 +132,9 @@ class AliyunDataTrace(BaseTraceInstance):
             status = Status(StatusCode.ERROR, trace_info.error)
 
         trace_id = convert_to_trace_id(message_id)
+        if trace_info.trace_id:
+            trace_id = convert_string_to_id(trace_info.trace_id)
+
         message_span_id = convert_to_span_id(message_id, "message")
         message_span = SpanData(
             trace_id=trace_id,
@@ -186,9 +191,13 @@ class AliyunDataTrace(BaseTraceInstance):
             return
 
         message_id = trace_info.message_id
+        trace_id = convert_to_trace_id(message_id)
+        if trace_info.trace_id:
+            trace_id = convert_string_to_id(trace_info.trace_id)
+
         documents_data = extract_retrieval_documents(trace_info.documents)
         dataset_retrieval_span = SpanData(
-            trace_id=convert_to_trace_id(message_id),
+            trace_id=trace_id,
             parent_span_id=convert_to_span_id(message_id, "message"),
             span_id=generate_span_id(),
             name="dataset_retrieval",
@@ -214,8 +223,12 @@ class AliyunDataTrace(BaseTraceInstance):
         if trace_info.error:
             status = Status(StatusCode.ERROR, trace_info.error)
 
+        trace_id = convert_to_trace_id(message_id)
+        if trace_info.trace_id:
+            trace_id = convert_string_to_id(trace_info.trace_id)
+
         tool_span = SpanData(
-            trace_id=convert_to_trace_id(message_id),
+            trace_id=trace_id,
             parent_span_id=convert_to_span_id(message_id, "message"),
             span_id=generate_span_id(),
             name=trace_info.tool_name,
@@ -451,8 +464,13 @@ class AliyunDataTrace(BaseTraceInstance):
         status: Status = Status(StatusCode.OK)
         if trace_info.error:
             status = Status(StatusCode.ERROR, trace_info.error)
+
+        trace_id = convert_to_trace_id(message_id)
+        if trace_info.trace_id:
+            trace_id = convert_string_to_id(trace_info.trace_id)
+
         suggested_question_span = SpanData(
-            trace_id=convert_to_trace_id(message_id),
+            trace_id=trace_id,
             parent_span_id=convert_to_span_id(message_id, "message"),
             span_id=convert_to_span_id(message_id, "suggested_question"),
             name="suggested_question",
 
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
index 934ce95a64..bd19c8a503 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -181,15 +181,21 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
         raise ValueError(f"Invalid UUID input: {e}")
 
 
+def convert_string_to_id(string: Optional[str]) -> int:
+    if not string:
+        return generate_span_id()
+    hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
+    id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
+    return id
+
+
 def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
     try:
         uuid_obj = uuid.UUID(uuid_v4)
     except Exception as e:
         raise ValueError(f"Invalid UUID input: {e}")
     combined_key = f"{uuid_obj.hex}-{span_type}"
-    hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
-    span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
-    return span_id
+    return convert_string_to_id(combined_key)
 
 
 def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
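`convert_string_to_id` (and `convert_to_span_id`, which now delegates to it) derives a stable 64-bit identifier by hashing the input and keeping the first 8 bytes of the digest. A standalone replica for illustration:

```python
import hashlib

# Replica of the hashing scheme above: SHA-256 the string, keep the first
# 8 bytes of the digest, and read them as an unsigned big-endian integer.
# The same input therefore always yields the same id.
def string_to_id64(s: str) -> int:
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)


assert string_to_id64("my-workflow-run") == string_to_id64("my-workflow-run")
assert string_to_id64("my-workflow-run") < 2**64
```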
+ """ if string is None: string = "" hash_object = hashlib.sha256(string.encode()) - # Take the first 16 bytes (128 bits) of the hash + # Take the first 16 bytes (128 bits) of the hash digest digest = hash_object.digest()[:16] - # Convert to integer (128 bits) + # Convert to a 128-bit integer return int.from_bytes(digest, byteorder="big") @@ -153,8 +165,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -310,7 +321,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id) message_span_id = RandomIdGenerator().generate_span_id() span_context = SpanContext( trace_id=trace_id, @@ -406,7 +417,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -468,7 +479,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -521,7 +532,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -568,7 +579,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) tool_span_id = RandomIdGenerator().generate_span_id() logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id) @@ -629,7 +640,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 626782cee5..851a77fbc1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com") + return validate_url_with_path(v, "https://app.phoenix.arize.com") class LangfuseConfig(BaseTracingConfig): diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 151fa2aaf4..3bad5c92fb 100644 --- a/api/core/ops/entities/trace_entity.py +++ 
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py
index 151fa2aaf4..3bad5c92fb 100644
--- a/api/core/ops/entities/trace_entity.py
+++ b/api/core/ops/entities/trace_entity.py
@@ -14,6 +14,7 @@ class BaseTraceInfo(BaseModel):
     start_time: Optional[datetime] = None
     end_time: Optional[datetime] = None
     metadata: dict[str, Any]
+    trace_id: Optional[str] = None
 
     @field_validator("inputs", "outputs")
     @classmethod
 
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index d356e735ee..3a03d9f4fe 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -67,14 +67,13 @@ class LangFuseDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        external_trace_id = trace_info.metadata.get("external_trace_id")
-        trace_id = external_trace_id or trace_info.workflow_run_id
+        trace_id = trace_info.trace_id or trace_info.workflow_run_id
         user_id = trace_info.metadata.get("user_id")
         metadata = trace_info.metadata
         metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
 
         if trace_info.message_id:
-            trace_id = external_trace_id or trace_info.message_id
+            trace_id = trace_info.trace_id or trace_info.message_id
             name = TraceTaskName.MESSAGE_TRACE.value
             trace_data = LangfuseTrace(
                 id=trace_id,
@@ -250,8 +249,10 @@ class LangFuseDataTrace(BaseTraceInstance):
             user_id = end_user_data.session_id
             metadata["user_id"] = user_id
 
+        trace_id = trace_info.trace_id or message_id
+
         trace_data = LangfuseTrace(
-            id=message_id,
+            id=trace_id,
             user_id=user_id,
             name=TraceTaskName.MESSAGE_TRACE.value,
             input={
@@ -285,7 +286,7 @@ class LangFuseDataTrace(BaseTraceInstance):
 
         langfuse_generation_data = LangfuseGeneration(
             name="llm",
-            trace_id=message_id,
+            trace_id=trace_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             model=message_data.model_id,
@@ -311,7 +312,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 "preset_response": trace_info.preset_response,
                 "inputs": trace_info.inputs,
             },
-            trace_id=trace_info.message_id,
+            trace_id=trace_info.trace_id or trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
             end_time=trace_info.end_time or trace_info.message_data.created_at,
             metadata=trace_info.metadata,
@@ -334,7 +335,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             input=trace_info.inputs,
             output=str(trace_info.suggested_question),
-            trace_id=trace_info.message_id,
+            trace_id=trace_info.trace_id or trace_info.message_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
@@ -352,7 +353,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             input=trace_info.inputs,
             output={"documents": trace_info.documents},
-            trace_id=trace_info.message_id,
+            trace_id=trace_info.trace_id or trace_info.message_id,
             start_time=trace_info.start_time or trace_info.message_data.created_at,
             end_time=trace_info.end_time or trace_info.message_data.updated_at,
             metadata=trace_info.metadata,
@@ -365,7 +366,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             name=trace_info.tool_name,
             input=trace_info.tool_inputs,
             output=trace_info.tool_outputs,
-            trace_id=trace_info.message_id,
+            trace_id=trace_info.trace_id or trace_info.message_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index fb3f6ecf0d..f9e5128e89 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -65,8 +65,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        external_trace_id = trace_info.metadata.get("external_trace_id")
-        trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
+        trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
         if trace_info.start_time is None:
             trace_info.start_time = datetime.now()
         message_dotted_order = (
@@ -290,7 +289,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             parent_run_id=None,
         )
@@ -319,7 +318,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             id=str(uuid.uuid4()),
         )
@@ -351,7 +350,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             error="",
             file_list=[],
@@ -381,7 +380,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             error="",
             file_list=[],
@@ -410,7 +409,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             error="",
             file_list=[],
@@ -440,7 +439,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             error=trace_info.error or "",
         )
@@ -465,7 +464,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             reference_example_id=None,
             input_attachments={},
             output_attachments={},
-            trace_id=None,
+            trace_id=trace_info.trace_id,
             dotted_order=None,
             error="",
             file_list=[],
 
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index 1e52f28350..dd6a424ddb 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -96,8 +96,7 @@ class OpikDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        external_trace_id = trace_info.metadata.get("external_trace_id")
-        dify_trace_id = external_trace_id or trace_info.workflow_run_id
+        dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
         opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
         workflow_metadata = wrap_metadata(
             trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
@@ -105,7 +104,7 @@ class OpikDataTrace(BaseTraceInstance):
         root_span_id = None
 
         if trace_info.message_id:
-            dify_trace_id = external_trace_id or trace_info.message_id
+            dify_trace_id = trace_info.trace_id or trace_info.message_id
             opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
 
         trace_data = {
@@ -276,7 +275,7 @@ class OpikDataTrace(BaseTraceInstance):
             return
 
         metadata = trace_info.metadata
-        message_id = trace_info.message_id
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
 
         user_id = message_data.from_account_id
         metadata["user_id"] = user_id
@@ -291,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
             metadata["end_user_id"] = end_user_id
 
         trace_data = {
-            "id": prepare_opik_uuid(trace_info.start_time, message_id),
+            "id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
             "name": TraceTaskName.MESSAGE_TRACE.value,
             "start_time": trace_info.start_time,
             "end_time": trace_info.end_time,
@@ -330,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
         start_time = trace_info.start_time or trace_info.message_data.created_at
 
         span_data = {
-            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
             "name": TraceTaskName.MODERATION_TRACE.value,
             "type": "tool",
             "start_time": start_time,
@@ -356,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
         start_time = trace_info.start_time or message_data.created_at
 
         span_data = {
-            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
             "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             "type": "tool",
             "start_time": start_time,
@@ -376,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
         start_time = trace_info.start_time or trace_info.message_data.created_at
 
         span_data = {
-            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
             "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             "type": "tool",
             "start_time": start_time,
@@ -391,7 +390,7 @@ class OpikDataTrace(BaseTraceInstance):
 
     def tool_trace(self, trace_info: ToolTraceInfo):
         span_data = {
-            "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
+            "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
             "name": trace_info.tool_name,
             "type": "tool",
             "start_time": trace_info.start_time,
@@ -406,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
 
     def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
         trace_data = {
-            "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
+            "id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
             "name": TraceTaskName.GENERATE_NAME_TRACE.value,
             "start_time": trace_info.start_time,
             "end_time": trace_info.end_time,
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 91cdc937a6..7eb5da7e3a 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -322,7 +322,7 @@ class OpsTraceManager:
         :return:
         """
         # auth check
-        if enabled == True:
+        if enabled:
             try:
                 provider_config_map[tracing_provider]
             except KeyError:
@@ -422,8 +422,11 @@ class TraceTask:
         self.timer = timer
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
         self.app_id = None
-
+        self.trace_id = None
         self.kwargs = kwargs
+        external_trace_id = kwargs.get("external_trace_id")
+        if external_trace_id:
+            self.trace_id = external_trace_id
 
     def execute(self):
         return self.preprocess()
@@ -520,11 +523,8 @@ class TraceTask:
             "app_id": workflow_run.app_id,
         }
 
-        external_trace_id = self.kwargs.get("external_trace_id")
-        if external_trace_id:
-            metadata["external_trace_id"] = external_trace_id
-
         workflow_trace_info = WorkflowTraceInfo(
+            trace_id=self.trace_id,
             workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
@@ -584,6 +584,7 @@ class TraceTask:
         message_tokens = message_data.message_tokens
 
         message_trace_info = MessageTraceInfo(
+            trace_id=self.trace_id,
             message_id=message_id,
             message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
@@ -627,6 +628,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
 
         moderation_trace_info = ModerationTraceInfo(
+            trace_id=self.trace_id,
             message_id=workflow_app_log_id or message_id,
             inputs=inputs,
             message_data=message_data.to_dict(),
@@ -667,6 +669,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
 
         suggested_question_trace_info = SuggestedQuestionTraceInfo(
+            trace_id=self.trace_id,
            message_id=workflow_app_log_id or message_id,
             message_data=message_data.to_dict(),
             inputs=message_data.message,
@@ -708,6 +711,7 @@ class TraceTask:
         }
 
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
+            trace_id=self.trace_id,
             message_id=message_id,
             inputs=message_data.query or message_data.inputs,
             documents=[doc.model_dump() for doc in documents] if documents else [],
@@ -772,6 +776,7 @@ class TraceTask:
         )
 
         tool_trace_info = ToolTraceInfo(
+            trace_id=self.trace_id,
             message_id=message_id,
             message_data=message_data.to_dict(),
             tool_name=tool_name,
@@ -807,6 +812,7 @@ class TraceTask:
         }
 
         generate_name_trace_info = GenerateNameTraceInfo(
+            trace_id=self.trace_id,
             conversation_id=conversation_id,
             inputs=inputs,
             outputs=generate_conversation_name,
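Across the providers touched above and below (Aliyun, Arize/Phoenix, Langfuse, LangSmith, Opik, Weave), the change converges on one precedence rule, now fed by `TraceTask.trace_id` instead of ad-hoc `metadata["external_trace_id"]` lookups. Reduced to a sketch:

```python
# The precedence rule the trace providers now share: a caller-supplied
# external trace id wins, then the message id, then the workflow run id.
# Names are illustrative.
def resolve_trace_id(external_trace_id=None, message_id=None, workflow_run_id=None):
    return external_trace_id or message_id or workflow_run_id


assert resolve_trace_id(workflow_run_id="run-1") == "run-1"
assert resolve_trace_id(message_id="msg-1", workflow_run_id="run-1") == "msg-1"
assert resolve_trace_id("ext-1", "msg-1", "run-1") == "ext-1"
```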
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
index 470601b17a..8089860481 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -87,8 +87,7 @@ class WeaveDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        external_trace_id = trace_info.metadata.get("external_trace_id")
-        trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
+        trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
         if trace_info.start_time is None:
             trace_info.start_time = datetime.now()
 
@@ -245,8 +244,12 @@ class WeaveDataTrace(BaseTraceInstance):
         attributes["start_time"] = trace_info.start_time
         attributes["end_time"] = trace_info.end_time
         attributes["tags"] = ["message", str(trace_info.conversation_mode)]
+
+        trace_id = trace_info.trace_id or message_id
+        attributes["trace_id"] = trace_id
+
         message_run = WeaveTraceModel(
-            id=message_id,
+            id=trace_id,
             op=str(TraceTaskName.MESSAGE_TRACE.value),
             input_tokens=trace_info.message_tokens,
             output_tokens=trace_info.answer_tokens,
@@ -274,7 +277,7 @@ class WeaveDataTrace(BaseTraceInstance):
         )
         self.start_call(
             llm_run,
-            parent_run_id=message_id,
+            parent_run_id=trace_id,
         )
         self.finish_call(llm_run)
         self.finish_call(message_run)
@@ -289,6 +292,9 @@ class WeaveDataTrace(BaseTraceInstance):
         attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
         attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
 
+        trace_id = trace_info.trace_id or trace_info.message_id
+        attributes["trace_id"] = trace_id
+
         moderation_run = WeaveTraceModel(
             id=str(uuid.uuid4()),
             op=str(TraceTaskName.MODERATION_TRACE.value),
@@ -303,7 +309,7 @@ class WeaveDataTrace(BaseTraceInstance):
             exception=getattr(trace_info, "error", None),
             file_list=[],
         )
-        self.start_call(moderation_run, parent_run_id=trace_info.message_id)
+        self.start_call(moderation_run, parent_run_id=trace_id)
         self.finish_call(moderation_run)
 
     def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
@@ -316,6 +322,9 @@ class WeaveDataTrace(BaseTraceInstance):
         attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
         attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
 
+        trace_id = trace_info.trace_id or trace_info.message_id
+        attributes["trace_id"] = trace_id
+
         suggested_question_run = WeaveTraceModel(
             id=str(uuid.uuid4()),
             op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
@@ -326,7 +335,7 @@ class WeaveDataTrace(BaseTraceInstance):
             file_list=[],
         )
 
-        self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
+        self.start_call(suggested_question_run, parent_run_id=trace_id)
         self.finish_call(suggested_question_run)
 
     def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
@@ -338,6 +347,9 @@ class WeaveDataTrace(BaseTraceInstance):
         attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
         attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
 
+        trace_id = trace_info.trace_id or trace_info.message_id
+        attributes["trace_id"] = trace_id
+
         dataset_retrieval_run = WeaveTraceModel(
             id=str(uuid.uuid4()),
             op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
@@ -348,7 +360,7 @@ class WeaveDataTrace(BaseTraceInstance):
             file_list=[],
         )
 
-        self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
+        self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
         self.finish_call(dataset_retrieval_run)
 
     def tool_trace(self, trace_info: ToolTraceInfo):
@@ -357,6 +369,11 @@ class WeaveDataTrace(BaseTraceInstance):
         attributes["start_time"] = trace_info.start_time
         attributes["end_time"] = trace_info.end_time
 
+        message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
+        message_id = message_id or None
+        trace_id = trace_info.trace_id or message_id
+        attributes["trace_id"] = trace_id
+
         tool_run = WeaveTraceModel(
             id=str(uuid.uuid4()),
             op=trace_info.tool_name,
@@ -366,9 +383,7 @@ class WeaveDataTrace(BaseTraceInstance):
             attributes=attributes,
             exception=trace_info.error,
         )
-        message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
-        message_id = message_id or None
-        self.start_call(tool_run, parent_run_id=message_id)
+        self.start_call(tool_run, parent_run_id=trace_id)
         self.finish_call(tool_run)
 
     def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
SuggestedQuestionTraceInfo): @@ -316,6 +322,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or message_data.created_at,) attributes["end_time"] = (trace_info.end_time or message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value), @@ -326,7 +335,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) + self.start_call(suggested_question_run, parent_run_id=trace_id) self.finish_call(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): @@ -338,6 +347,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value), @@ -348,7 +360,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) + self.start_call(dataset_retrieval_run, parent_run_id=trace_id) self.finish_call(dataset_retrieval_run) def tool_trace(self, trace_info: ToolTraceInfo): @@ -357,6 +369,11 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time attributes["end_time"] = trace_info.end_time + message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) + message_id = message_id or None + trace_id = trace_info.trace_id or message_id + attributes["trace_id"] = trace_id + tool_run = WeaveTraceModel( id=str(uuid.uuid4()), op=trace_info.tool_name, @@ -366,9 +383,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) - message_id = message_id or None - self.start_call(tool_run, parent_run_id=message_id) + self.start_call(tool_run, parent_run_id=trace_id) self.finish_call(tool_run) def generate_name_trace(self, trace_info: GenerateNameTraceInfo): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 832485b236..9dea050dc3 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -22,22 +22,50 @@ logger = logging.getLogger(__name__) class ElasticSearchConfig(BaseModel): - host: str - port: int - username: str - password: str + # Regular Elasticsearch config + host: Optional[str] = None + port: Optional[int] = None + username: Optional[str] = None + password: Optional[str] = None + + # Elastic Cloud specific config + cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud + api_key: Optional[str] = None + + # Common config + use_cloud: bool = False + ca_certs: Optional[str] = None + verify_certs: bool = False + request_timeout: int = 100000 + retry_on_timeout: bool = True + max_retries: int = 10000 @model_validator(mode="before") @classmethod def validate_config(cls, values: dict) -> dict: - if not values["host"]: - raise ValueError("config HOST is 
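The `ElasticSearchConfig` rewrite turns a single required field set into two mutually exclusive modes, enforced in one `model_validator`. A condensed, runnable sketch of the validation split (field set trimmed for brevity):

```python
from typing import Optional
from pydantic import BaseModel, model_validator

class ESConfig(BaseModel):
    # self-hosted
    host: Optional[str] = None
    port: Optional[int] = None
    username: Optional[str] = None
    password: Optional[str] = None
    # Elastic Cloud
    use_cloud: bool = False
    cloud_url: Optional[str] = None
    api_key: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check(cls, values: dict) -> dict:
        if values.get("use_cloud"):
            # Cloud mode: endpoint URL + API key are the only hard requirements.
            for key in ("cloud_url", "api_key"):
                if not values.get(key):
                    raise ValueError(f"{key} is required for Elastic Cloud")
        else:
            # Self-hosted mode keeps the classic host/port/basic-auth checks.
            for key in ("host", "port", "username", "password"):
                if not values.get(key):
                    raise ValueError(f"config {key.upper()} is required")
        return values

ESConfig(use_cloud=True, cloud_url="https://example.es.io:443", api_key="k")
```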
required") - if not values["port"]: - raise ValueError("config PORT is required") - if not values["username"]: - raise ValueError("config USERNAME is required") - if not values["password"]: - raise ValueError("config PASSWORD is required") + use_cloud = values.get("use_cloud", False) + cloud_url = values.get("cloud_url") + + if use_cloud: + # Cloud configuration validation - requires cloud_url and api_key + if not cloud_url: + raise ValueError("cloud_url is required for Elastic Cloud") + + api_key = values.get("api_key") + if not api_key: + raise ValueError("api_key is required for Elastic Cloud") + + else: + # Regular Elasticsearch validation + if not values.get("host"): + raise ValueError("config HOST is required for regular Elasticsearch") + if not values.get("port"): + raise ValueError("config PORT is required for regular Elasticsearch") + if not values.get("username"): + raise ValueError("config USERNAME is required for regular Elasticsearch") + if not values.get("password"): + raise ValueError("config PASSWORD is required for regular Elasticsearch") + return values @@ -50,21 +78,69 @@ class ElasticSearchVector(BaseVector): self._attributes = attributes def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: + """ + Initialize Elasticsearch client for both regular Elasticsearch and Elastic Cloud. + """ try: - parsed_url = urlparse(config.host) - if parsed_url.scheme in {"http", "https"}: - hosts = f"{config.host}:{config.port}" + # Check if using Elastic Cloud + client_config: dict[str, Any] + if config.use_cloud and config.cloud_url: + client_config = { + "request_timeout": config.request_timeout, + "retry_on_timeout": config.retry_on_timeout, + "max_retries": config.max_retries, + "verify_certs": config.verify_certs, + } + + # Parse cloud URL and configure hosts + parsed_url = urlparse(config.cloud_url) + host = f"{parsed_url.scheme}://{parsed_url.hostname}" + if parsed_url.port: + host += f":{parsed_url.port}" + + client_config["hosts"] = [host] + + # API key authentication for cloud + client_config["api_key"] = config.api_key + + # SSL settings + if config.ca_certs: + client_config["ca_certs"] = config.ca_certs + else: - hosts = f"http://{config.host}:{config.port}" - client = Elasticsearch( - hosts=hosts, - basic_auth=(config.username, config.password), - request_timeout=100000, - retry_on_timeout=True, - max_retries=10000, - ) - except requests.exceptions.ConnectionError: - raise ConnectionError("Vector database connection error") + # Regular Elasticsearch configuration + parsed_url = urlparse(config.host or "") + if parsed_url.scheme in {"http", "https"}: + hosts = f"{config.host}:{config.port}" + use_https = parsed_url.scheme == "https" + else: + hosts = f"http://{config.host}:{config.port}" + use_https = False + + client_config = { + "hosts": [hosts], + "basic_auth": (config.username, config.password), + "request_timeout": config.request_timeout, + "retry_on_timeout": config.retry_on_timeout, + "max_retries": config.max_retries, + } + + # Only add SSL settings if using HTTPS + if use_https: + client_config["verify_certs"] = config.verify_certs + if config.ca_certs: + client_config["ca_certs"] = config.ca_certs + + client = Elasticsearch(**client_config) + + # Test connection + if not client.ping(): + raise ConnectionError("Failed to connect to Elasticsearch") + + except requests.exceptions.ConnectionError as e: + raise ConnectionError(f"Vector database connection error: {str(e)}") + except Exception as e: + raise ConnectionError(f"Elasticsearch client 
initialization failed: {str(e)}") return client @@ -209,7 +285,11 @@ class ElasticSearchVector(BaseVector): }, } } + self._client.indices.create(index=self._collection_name, mappings=mappings) + logger.info("Created index %s with dimension %s", self._collection_name, dim) + else: + logger.info("Collection %s already exists.", self._collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -225,13 +305,51 @@ class ElasticSearchVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config + + # Check if ELASTICSEARCH_USE_CLOUD is explicitly set to false (boolean) + use_cloud_env = config.get("ELASTICSEARCH_USE_CLOUD", False) + + if use_cloud_env is False: + # Use regular Elasticsearch with config values + config_dict = { + "use_cloud": False, + "host": config.get("ELASTICSEARCH_HOST", "elasticsearch"), + "port": config.get("ELASTICSEARCH_PORT", 9200), + "username": config.get("ELASTICSEARCH_USERNAME", "elastic"), + "password": config.get("ELASTICSEARCH_PASSWORD", "elastic"), + } + else: + # Check for cloud configuration + cloud_url = config.get("ELASTICSEARCH_CLOUD_URL") + if cloud_url: + config_dict = { + "use_cloud": True, + "cloud_url": cloud_url, + "api_key": config.get("ELASTICSEARCH_API_KEY"), + } + else: + # Fallback to regular Elasticsearch + config_dict = { + "use_cloud": False, + "host": config.get("ELASTICSEARCH_HOST", "localhost"), + "port": config.get("ELASTICSEARCH_PORT", 9200), + "username": config.get("ELASTICSEARCH_USERNAME", "elastic"), + "password": config.get("ELASTICSEARCH_PASSWORD", ""), + } + + # Common configuration + config_dict.update( + { + "ca_certs": str(config.get("ELASTICSEARCH_CA_CERTS")) if config.get("ELASTICSEARCH_CA_CERTS") else None, + "verify_certs": bool(config.get("ELASTICSEARCH_VERIFY_CERTS", False)), + "request_timeout": int(config.get("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), + "retry_on_timeout": bool(config.get("ELASTICSEARCH_RETRY_ON_TIMEOUT", True)), + "max_retries": int(config.get("ELASTICSEARCH_MAX_RETRIES", 10000)), + } + ) + return ElasticSearchVector( index_name=collection_name, - config=ElasticSearchConfig( - host=config.get("ELASTICSEARCH_HOST", "localhost"), - port=config.get("ELASTICSEARCH_PORT", 9200), - username=config.get("ELASTICSEARCH_USERNAME", ""), - password=config.get("ELASTICSEARCH_PASSWORD", ""), - ), + config=ElasticSearchConfig(**config_dict), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 784e27fc7f..91d667ff2c 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import Any, Optional import tablestore # type: ignore @@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel): access_key_secret: Optional[str] = None instance_name: Optional[str] = None endpoint: Optional[str] = None + normalize_full_text_bm25_score: Optional[bool] = False @model_validator(mode="before") @classmethod @@ -47,6 +49,7 @@ class TableStoreVector(BaseVector): config.access_key_secret, config.instance_name, ) + self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score self._table_name = f"{collection_name}" self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" @@ -131,8 +134,8 @@ class 
TableStoreVector(BaseVector): filtered_list = None if document_ids_filter: filtered_list = ["document_id=" + item for item in document_ids_filter] - - return self._search_by_full_text(query, filtered_list, top_k) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._search_by_full_text(query, filtered_list, top_k, score_threshold) def delete(self) -> None: self._delete_table_if_exist() @@ -318,7 +321,19 @@ class TableStoreVector(BaseVector): documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: + @staticmethod + def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float: + """ + Args: + score: BM25 search score. + k: decay factor, the larger the k, the steeper the low score end + """ + normalized_score = 1 - math.exp(-k * score) + return max(0.0, min(1.0, normalized_score)) + + def _search_by_full_text( + self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float + ) -> list[Document]: bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) @@ -339,15 +354,27 @@ class TableStoreVector(BaseVector): documents = [] for search_hit in search_response.search_hits: + score = None + if self._normalize_full_text_bm25_score: + score = self._normalize_score_exp_decay(search_hit.score) + + # skip when score is below threshold and use normalize score + if score and score <= score_threshold: + continue + ots_column_map = {} for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - vector_str = ots_column_map.get(Field.VECTOR.value) metadata_str = ots_column_map.get(Field.METADATA_KEY.value) - vector = json.loads(vector_str) if vector_str else None metadata = json.loads(metadata_str) if metadata_str else {} + vector_str = ots_column_map.get(Field.VECTOR.value) + vector = json.loads(vector_str) if vector_str else None + + if score: + metadata["score"] = score + documents.append( Document( page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", @@ -355,6 +382,8 @@ class TableStoreVector(BaseVector): metadata=metadata, ) ) + if self._normalize_full_text_bm25_score: + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents @@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory): instance_name=dify_config.TABLESTORE_INSTANCE_NAME, access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID, access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET, + normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE, ), ) diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 6ef932ad22..1f054bccdb 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 875626eb34..17f4d1af2d 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,5 +1,6 @@ import json import logging +import operator from typing import Any, Optional, cast 
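The motivation for `_normalize_score_exp_decay`: raw BM25 scores are unbounded, so a `score_threshold` tuned for cosine-style similarities in [0, 1] cannot be applied to them directly. The exponential decay `1 - e^(-k*s)` maps any non-negative score into [0, 1) monotonically, with diminishing gains at the high end. A tiny worked example:

```python
import math

def normalize(score: float, k: float = 0.15) -> float:
    # Exponential decay squashes unbounded BM25 scores into [0, 1):
    # low scores stay low, and gains flatten as scores grow.
    return max(0.0, min(1.0, 1 - math.exp(-k * score)))

for s in (1, 5, 20):
    print(s, round(normalize(s), 3))
# 1 -> 0.139, 5 -> 0.528, 20 -> 0.95
```

With `k = 0.15`, a BM25 score of roughly 5 lands near 0.5 and anything above ~20 saturates toward 1, which is why the decay factor controls how steep the low-score end is.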
import requests @@ -130,13 +131,15 @@ class NotionExtractor(BaseExtractor): data[property_name] = value row_dict = {k: v for k, v in data.items() if v} row_content = "" - for key, value in row_dict.items(): + for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) row_content = row_content + f"{key}:{value_content}\n" else: row_content = row_content + f"{key}:{value}\n" + if "url" in result: + row_content = row_content + f"Row Page URL:{result.get('url', '')}\n" database_content.append(row_content) has_more = response_data.get("has_more", False) diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index f9b776b3b9..91316b859a 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool): target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore if not target_time: yield self.create_text_message( - f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" + f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}" ) return diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 962b9f7a81..db6b84082f 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -142,7 +142,7 @@ class WorkflowTool(Tool): if not version: workflow = ( db.session.query(Workflow) - .where(Workflow.app_id == app_id, Workflow.version != "draft") + .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) .order_by(Workflow.created_at.desc()) .first() ) diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index f3061f7d96..23512c8ce4 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str: for i in range(1, len(raw_results)): spk, txt = raw_results[i] - if spk == None: + if spk is None: merged_results.append((None, current_text)) continue diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 8ac1ae8526..2106369bd6 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -265,9 +265,9 @@ class Executor: if not authorization.config.header: authorization.config.header = "Authorization" - if self.auth.config.type == "bearer": + if self.auth.config.type == "bearer" and authorization.config.api_key: headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic": + elif self.auth.config.type == "basic" and authorization.config.api_key: credentials = authorization.config.api_key if ":" in credentials: encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") @@ -277,6 +277,22 @@ class Executor: elif self.auth.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key or "" + # Handle Content-Type for multipart/form-data requests + # Fix for issue #22880: Missing 
boundary when using multipart/form-data + body = self.node_data.body + if body and body.type == "form-data": + # For multipart/form-data with files, let httpx handle the boundary automatically + # by not setting Content-Type header when files are present + if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files): + # Only set Content-Type when there are no actual files + # This ensures httpx generates the correct boundary + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = "multipart/form-data" + elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: + # Set Content-Type for other body types + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + return headers def _validate_and_parse_response(self, response: httpx.Response) -> Response: @@ -384,15 +400,24 @@ class Executor: # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. # This prevents logging meaningless placeholder entries. if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for key, (filename, content, mime_type) in self.files: + for file_entry in self.files: + # file_entry should be (key, (filename, content, mime_type)), but handle edge cases + if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2: + continue # skip malformed entries + key = file_entry[0] + content = file_entry[1][1] body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - # fix: decode binary content - pass + # decode content safely + if isinstance(content, bytes): + try: + body_string += content.decode("utf-8") + except UnicodeDecodeError: + body_string += content.decode("utf-8", errors="replace") + elif isinstance(content, str): + body_string += content + else: + body_string += f"[Unsupported content type: {type(content).__name__}]" body_string += "\r\n" body_string += f"--{boundary}--\r\n" elif self.node_data.body: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index f1767bdf9e..b71271abeb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e041e217ca..7303b68501 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode): **{key: metadata_name, key_value: f"%{value}"} ) ) + case "in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) + case "not in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in 
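The multipart fix hinges on a detail of how HTTP clients build `multipart/form-data` bodies: the `Content-Type` header must carry the `boundary` parameter that separates the parts, and only the client that encodes the body knows it. Setting the header by hand therefore drops the boundary. A runnable illustration with httpx (requests are only constructed here, nothing is sent; the URL is a placeholder):

```python
import httpx

files = {"file": ("report.txt", b"hello", "text/plain")}

# Let httpx build the multipart body: it generates the boundary and sets
# "multipart/form-data; boundary=..." itself.
ok = httpx.Request("POST", "https://example.com/upload", files=files)
assert ok.headers["Content-Type"].startswith("multipart/form-data; boundary=")

# Forcing the header manually wins over the generated one and loses the
# boundary parameter, which is the failure mode issue #22880 describes.
bad = httpx.Request("POST", "https://example.com/upload", files=files,
                    headers={"Content-Type": "multipart/form-data"})
assert "boundary" not in bad.headers["Content-Type"]
```

This is why the executor only sets `Content-Type` itself when no real files are present, and otherwise leaves the header for httpx to fill in.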
str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) case "=" | "is": if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 90a0397b67..dfc2a0000b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,7 +3,7 @@ import io import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.entities.model_entities import ( - AIModelEntity, ModelFeature, ModelPropertyKey, ModelType, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -1006,21 +1004,6 @@ class LLMNode(BaseNode): ) return saved_file - def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: - """ - Fetch model schema - """ - model_name = self._node_data.model.name - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name - ) - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_credentials = model_instance.credentials - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_schema - @staticmethod def fetch_structured_output_schema( *, diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 512a9cb608..b2bcee5dcd 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,4 +1,6 @@ import mimetypes +import os +import urllib.parse import uuid from collections.abc import Callable, Mapping, Sequence from typing import Any, cast @@ -240,16 +242,21 @@ def _build_from_remote_url( def _get_remote_file_info(url: str): file_size = -1 - filename = url.split("/")[-1].split("?")[0] or "unknown_file" - mime_type = mimetypes.guess_type(filename)[0] or "" + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + + # Initialize mime_type from filename as fallback + mime_type, _ = mimetypes.guess_type(filename) resp = ssrf_proxy.head(url, follow_redirects=True) resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) + # Re-guess mime_type from updated filename + mime_type, _ = mimetypes.guess_type(filename) file_size = int(resp.headers.get("Content-Length", file_size)) - mime_type = mime_type or str(resp.headers.get("Content-Type", "")) return mime_type, filename, file_size diff --git 
a/api/models/account.py b/api/models/account.py index d63c5d7fb5..3437055893 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -86,23 +86,21 @@ class Account(UserMixin, Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - email: Mapped[str] = mapped_column(db.String(255)) - password: Mapped[Optional[str]] = mapped_column(db.String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(db.String(255)) - avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column( - db.DateTime, server_default=func.current_timestamp(), nullable=False - ) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + name: Mapped[str] = mapped_column(String(255)) + email: Mapped[str] = mapped_column(String(255)) + password: Mapped[Optional[str]] = mapped_column(String(255)) + password_salt: Mapped[Optional[str]] = mapped_column(String(255)) + avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -200,13 +198,13 @@ class Tenant(Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = 
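The model hunks that follow are one mechanical migration repeated across files: column types move from the Flask-SQLAlchemy aliases (`db.String`, `db.DateTime`) to direct `sqlalchemy` imports, and columns gain `Mapped[...]` annotations so type checkers see real Python types. Behavior is unchanged; only static typing improves. The pattern in isolation:

```python
from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class AccountSketch(Base):  # illustrative model, not the real Account
    __tablename__ = "account_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Nullable column: Optional[...] in the annotation is what SQLAlchemy
    # 2.0 reads for nullability; String(255) only constrains the DDL.
    avatar: Mapped[Optional[str]] = mapped_column(String(255))
    # Non-nullable with a server-side default timestamp.
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), nullable=False
    )
```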
mapped_column(String(255)) encrypt_public_key = db.Column(db.Text) - plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) custom_config: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -237,10 +235,10 @@ class TenantAccountJoin(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + role: Mapped[str] = mapped_column(String(16), server_default="normal") invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): @@ -253,11 +251,11 @@ class AccountIntegrate(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id: Mapped[str] = mapped_column(StringUUID) - provider: Mapped[str] = mapped_column(db.String(16)) - open_id: Mapped[str] = mapped_column(db.String(255)) - encrypted_token: Mapped[str] = mapped_column(db.String(255)) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + provider: Mapped[str] = mapped_column(String(16)) + open_id: Mapped[str] = mapped_column(String(255)) + encrypted_token: Mapped[str] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): @@ -269,14 +267,14 @@ class InvitationCode(Base): ) id: Mapped[int] = mapped_column(db.Integer) - batch: Mapped[str] = mapped_column(db.String(255)) - code: Mapped[str] = mapped_column(db.String(32)) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + batch: Mapped[str] = mapped_column(String(255)) + code: Mapped[str] = mapped_column(String(32)) + status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) used_by_tenant_id: Mapped[Optional[str]] = 
mapped_column(StringUUID) used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -298,10 +296,8 @@ class TenantPluginPermission(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column( - db.String(16), nullable=False, server_default="everyone" - ) - debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") + debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") class TenantPluginAutoUpgradeStrategy(Base): @@ -323,14 +319,10 @@ class TenantPluginAutoUpgradeStrategy(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only") + strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") + exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 3cef5a0fb2..ac9eda6829 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum +from datetime import datetime -from sqlalchemy import func -from sqlalchemy.orm import mapped_column +from sqlalchemy import DateTime, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -24,7 +25,7 @@ class APIBasedExtension(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) 
- name = mapped_column(db.String(255), nullable=False) - api_endpoint = mapped_column(db.String(255), nullable=False) - api_key = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(String(255), nullable=False) + api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False) + api_key = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 01372f8bf6..4d41d0c8b3 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,7 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -48,22 +48,22 @@ class Dataset(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) description = mapped_column(db.Text, nullable=True) - provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) - permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) - data_source_type = mapped_column(db.String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) + provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying")) + data_source_type = mapped_column(String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) index_struct = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column - embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def dataset_keyword_table(self): @@ -268,10 +268,10 @@ class DatasetProcessRule(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, 
nullable=False) - mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -313,61 +313,59 @@ class Document(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - data_source_type = mapped_column(db.String(255), nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) + data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) data_source_info = mapped_column(db.Text, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) - batch = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_from = mapped_column(db.String(255), nullable=False) + batch: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = mapped_column(db.DateTime, nullable=True) + processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # parsing file_id = mapped_column(db.Text, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - parsing_completed_at = mapped_column(db.DateTime, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at = mapped_column(db.DateTime, nullable=True) + cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at = mapped_column(db.DateTime, nullable=True) + splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # indexing - tokens = mapped_column(db.Integer, nullable=True) - indexing_latency = mapped_column(db.Float, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # pause - is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, 
server_default=db.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at = mapped_column(db.DateTime, nullable=True) + paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # error error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column( - db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") - ) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = mapped_column(db.String(255), nullable=True) + archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at = mapped_column(db.DateTime, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = mapped_column(db.String(40), nullable=True) + archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) - doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = mapped_column(db.String(255), nullable=True) + doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = mapped_column(String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -524,7 +522,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.upload_date, "type": "time", - "value": self.created_at.timestamp(), + "value": str(self.created_at.timestamp()), } ) built_in_fields.append( @@ -532,7 +530,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.last_update_date, "type": "time", - "value": self.updated_at.timestamp(), + "value": str(self.updated_at.timestamp()), } ) built_in_fields.append( @@ -667,23 +665,23 @@ class DocumentSegment(Base): # indexing fields keywords = mapped_column(db.JSON, nullable=True) - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) # basic fields - hit_count = mapped_column(db.Integer, nullable=False, default=0) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + enabled: Mapped[bool] = mapped_column(db.Boolean, 
nullable=False, server_default=db.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -808,19 +806,23 @@ class ChildChunk(Base): dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) segment_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) content = mapped_column(db.Text, nullable=False) - word_count = mapped_column(db.Integer, nullable=False) + word_count: Mapped[int] = mapped_column(db.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) - type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) + type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) error = mapped_column(db.Text, nullable=True) @property @@ -846,7 +848,7 @@ class AppDatasetJoin(Base): id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = 
mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -863,11 +865,11 @@ class DatasetQuery(Base): id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) content = mapped_column(db.Text, nullable=False) - source = mapped_column(db.String(255), nullable=False) + source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id = mapped_column(StringUUID, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): @@ -881,7 +883,7 @@ class DatasetKeywordTable(Base): dataset_id = mapped_column(StringUUID, nullable=False, unique=True) keyword_table = mapped_column(db.Text, nullable=False) data_source_type = mapped_column( - db.String(255), nullable=False, server_default=db.text("'database'::character varying") + String(255), nullable=False, server_default=db.text("'database'::character varying") ) @property @@ -925,12 +927,12 @@ class Embedding(Base): id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) model_name = mapped_column( - db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") ) - hash = mapped_column(db.String(64), nullable=False) + hash = mapped_column(String(64), nullable=False) embedding = mapped_column(db.LargeBinary, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -947,11 +949,11 @@ class DatasetCollectionBinding(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = mapped_column(db.String(255), nullable=False) - model_name = mapped_column(db.String(255), nullable=False) - type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - collection_name = mapped_column(db.String(64), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = 
mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): @@ -965,13 +967,13 @@ class TidbAuthBinding(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - cluster_id = mapped_column(db.String(255), nullable=False) - cluster_name = mapped_column(db.String(255), nullable=False) - active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = mapped_column(db.String(255), nullable=False) - password = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) + cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) + active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("CREATING")) + account: Mapped[str] = mapped_column(String(255), nullable=False) + password: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): @@ -982,8 +984,8 @@ class Whitelist(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - category = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): @@ -999,8 +1001,8 @@ class DatasetPermission(Base): dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): @@ -1012,14 +1014,14 @@ class ExternalKnowledgeApis(Base): ) id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) settings = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = 
mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1072,9 +1074,9 @@ class ExternalKnowledgeBindings(Base): dataset_id = mapped_column(StringUUID, nullable=False) external_knowledge_id = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): @@ -1090,8 +1092,10 @@ class DatasetAutoDisableLog(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class RateLimitLog(Base): @@ -1104,9 +1108,11 @@ class RateLimitLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - subscription_plan = mapped_column(db.String(255), nullable=False) - operation = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) + operation: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class DatasetMetadata(Base): @@ -1120,10 +1126,14 @@ class DatasetMetadata(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1143,5 +1153,5 @@ class DatasetMetadataBinding(Base): dataset_id = 
mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 9f6d51b315..fba0d692eb 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -37,7 +37,7 @@ class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -73,15 +73,15 @@ class App(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji - icon = db.Column(db.String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + mode: Mapped[str] = mapped_column(String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon = db.Column(String(255)) + icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) enable_site: Mapped[bool] = mapped_column(db.Boolean) enable_api: Mapped[bool] = mapped_column(db.Boolean) api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) @@ -306,8 +306,8 @@ class AppModelConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) configs = mapped_column(db.JSON, nullable=True) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -321,12 +321,12 @@ class AppModelConfig(Base): more_like_this = mapped_column(db.Text) model = mapped_column(db.Text) user_input_form = mapped_column(db.Text) - dataset_query_variable = mapped_column(db.String(255)) + dataset_query_variable = mapped_column(String(255)) pre_prompt = 
mapped_column(db.Text) agent_mode = mapped_column(db.Text) sensitive_word_avoidance = mapped_column(db.Text) retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = mapped_column(db.Text) completion_prompt_config = mapped_column(db.Text) dataset_configs = mapped_column(db.Text) @@ -561,14 +561,14 @@ class RecommendedApp(Base): id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) description = mapped_column(db.JSON, nullable=False) - copyright = mapped_column(db.String(255), nullable=False) - privacy_policy = mapped_column(db.String(255), nullable=False) + copyright: Mapped[str] = mapped_column(String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = mapped_column(db.String(255), nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_listed = mapped_column(db.Boolean, nullable=False, default=True) - install_count = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + category: Mapped[str] = mapped_column(String(255), nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying")) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -591,8 +591,8 @@ class InstalledApp(Base): tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -617,26 +617,26 @@ class Conversation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) - model_provider = mapped_column(db.String(255), nullable=True) + model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) - model_id = mapped_column(db.String(255), nullable=True) - mode: Mapped[str] = mapped_column(db.String(255)) - name = mapped_column(db.String(255), nullable=False) + model_id = mapped_column(String(255), nullable=True) + mode: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255), 
nullable=False) summary = mapped_column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = mapped_column(db.Text) system_instruction = mapped_column(db.Text) - system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - status = mapped_column(db.String(255), nullable=False) + system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(db.String(255), nullable=True) + invoke_from = mapped_column(String(255), nullable=True) # ref: ConversationSource. - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(db.DateTime) @@ -650,7 +650,7 @@ class Conversation(Base): "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): @@ -894,8 +894,8 @@ class Message(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - model_provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + model_provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -911,17 +911,17 @@ class Message(Base): parent_message_id = mapped_column(StringUUID, nullable=True) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = mapped_column(db.Text) message_metadata = mapped_column(db.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - from_source = mapped_column(db.String(255), nullable=False) + invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, 
server_default=db.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1238,9 +1238,9 @@ class MessageFeedback(Base): app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) - rating = mapped_column(db.String(255), nullable=False) + rating: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.Text) - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1298,12 +1298,12 @@ class MessageFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1323,7 +1323,7 @@ class MessageAnnotation(Base): message_id: Mapped[Optional[str]] = mapped_column(StringUUID) question = db.Column(db.Text, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) account_id = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1415,10 +1415,10 @@ class OperationLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - action = mapped_column(db.String(255), nullable=False) + action: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.JSON) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = mapped_column(db.String(255), nullable=False) + created_ip: Mapped[str] = mapped_column(String(255), nullable=False) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1433,10 +1433,10 @@ class EndUser(Base, UserMixin): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, 
nullable=True) - type = mapped_column(db.String(255), nullable=False) - external_user_id = mapped_column(db.String(255), nullable=True) - name = mapped_column(db.String(255)) - is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + external_user_id = mapped_column(String(255), nullable=True) + name = mapped_column(String(255)) + is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) session_id: Mapped[str] = mapped_column() created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1452,10 +1452,10 @@ class AppMCPServer(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) - server_code = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + server_code: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) parameters = mapped_column(db.Text, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1485,28 +1485,28 @@ class Site(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - title = mapped_column(db.String(255), nullable=False) - icon_type = mapped_column(db.String(255), nullable=True) - icon = mapped_column(db.String(255)) - icon_background = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(String(255), nullable=False) + icon_type = mapped_column(String(255), nullable=True) + icon = mapped_column(String(255)) + icon_background = mapped_column(String(255)) description = mapped_column(db.Text) - default_language = mapped_column(db.String(255), nullable=False) - chat_color_theme = mapped_column(db.String(255)) - chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = mapped_column(db.String(255)) - privacy_policy = mapped_column(db.String(255)) - show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + default_language: Mapped[str] = mapped_column(String(255), nullable=False) + chat_color_theme = mapped_column(String(255)) + chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + copyright = mapped_column(String(255)) + privacy_policy = mapped_column(String(255)) + show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - 
customize_domain = mapped_column(db.String(255)) - customize_token_strategy = mapped_column(db.String(255), nullable=False) - prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + customize_domain = mapped_column(String(255)) + customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = mapped_column(db.String(255)) + code = mapped_column(String(255)) @property def custom_disclaimer(self): @@ -1544,8 +1544,8 @@ class ApiToken(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - token = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1567,21 +1567,21 @@ class UploadFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) - key: Mapped[str] = mapped_column(db.String(255), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + key: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) size: Mapped[int] = mapped_column(db.Integer, nullable=False) - extension: Mapped[str] = mapped_column(db.String(255), nullable=False) - mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + extension: Mapped[str] = mapped_column(String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1630,10 +1630,10 
@@ class ApiRequest(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) - path = mapped_column(db.String(255), nullable=False) + path: Mapped[str] = mapped_column(String(255), nullable=False) request = mapped_column(db.Text, nullable=True) response = mapped_column(db.Text, nullable=True) - ip = mapped_column(db.String(255), nullable=False) + ip: Mapped[str] = mapped_column(String(255), nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1646,7 +1646,7 @@ class MessageChain(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) input = mapped_column(db.Text, nullable=True) output = mapped_column(db.Text, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1663,7 +1663,7 @@ class MessageAgentThought(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = mapped_column(StringUUID, nullable=True) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) thought = mapped_column(db.Text, nullable=True) tool = mapped_column(db.Text, nullable=True) tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) @@ -1673,19 +1673,19 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(db.Text, nullable=True) message = mapped_column(db.Text, nullable=True) - message_token = mapped_column(db.Integer, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) message_unit_price = mapped_column(db.Numeric, nullable=True) message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = mapped_column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) - answer_token = mapped_column(db.Integer, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) answer_unit_price = mapped_column(db.Numeric, nullable=True) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = mapped_column(db.Integer, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) total_price = mapped_column(db.Numeric, nullable=True) - currency = mapped_column(db.String, nullable=True) - latency = mapped_column(db.Float, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + currency = mapped_column(String, nullable=True) + latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1775,18 +1775,18 @@ class DatasetRetrieverResource(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = 
mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) dataset_name = mapped_column(db.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) document_name = mapped_column(db.Text, nullable=False) data_source_type = mapped_column(db.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score = mapped_column(db.Float, nullable=True) + score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - segment_position = mapped_column(db.Integer, nullable=True) + hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) index_node_hash = mapped_column(db.Text, nullable=True) retriever_from = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) @@ -1805,8 +1805,8 @@ class Tag(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - name = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1836,13 +1836,13 @@ class TraceAppConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(db.String(255), nullable=True) + tracing_provider = mapped_column(String(255), nullable=True) tracing_config = mapped_column(db.JSON, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90f..7bfc249b0b 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import func, text +from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -56,22 +56,22 @@ class Provider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_type: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'custom'::character varying") + String(40), nullable=False, 
server_default=text("'custom'::character varying") ) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( - db.String(40), nullable=True, server_default=text("''::character varying") + String(40), nullable=True, server_default=text("''::character varying") ) quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -113,13 +113,13 @@ class ProviderModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -131,11 +131,11 @@ class TenantDefaultModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) 
+ updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -147,10 +147,10 @@ class TenantPreferredModelProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -162,22 +162,22 @@ class ProviderOrder(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + currency: Mapped[Optional[str]] = mapped_column(String(40)) total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) payment_status: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -193,13 +193,13 @@ class ProviderModelSetting(Base): id: Mapped[str] = mapped_column(StringUUID, 
server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -215,11 +215,11 @@ class LoadBalancingModelConfig(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 100e0d96ef..8191c874a4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,8 +1,10 @@ import json +from datetime import datetime +from typing import Optional -from sqlalchemy import func +from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -20,12 +22,12 @@ class DataSourceOauthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - access_token = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + access_token: Mapped[str] = mapped_column(String(255), nullable=False) + provider: 
Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -38,12 +40,12 @@ class DataSourceApiKeyAuthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - category = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + category: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) credentials = mapped_column(db.Text, nullable=True) # JSON - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 3e5ebd2099..66a47ea4df 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Optional from celery import states # type: ignore +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now @@ -16,22 +17,22 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = mapped_column(db.String(155), unique=True) - status = mapped_column(db.String(50), default=states.PENDING) + task_id = mapped_column(String(155), unique=True) + status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) date_done = mapped_column( - db.DateTime, + DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) traceback = mapped_column(db.Text, nullable=True) - name = mapped_column(db.String(155), nullable=True) + name = mapped_column(String(155), nullable=True) args = mapped_column(db.LargeBinary, nullable=True) kwargs = mapped_column(db.LargeBinary, nullable=True) - worker = mapped_column(db.String(155), nullable=True) - retries = mapped_column(db.Integer, nullable=True) - queue = mapped_column(db.String(155), nullable=True) + worker = mapped_column(String(155), nullable=True) + retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + queue = mapped_column(String(155), nullable=True) class 
CeleryTaskSet(Base): @@ -42,6 +43,6 @@ class CeleryTaskSet(Base): id: Mapped[int] = mapped_column( db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) - taskset_id = mapped_column(db.String(155), unique=True) + taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 68f4211e59..1491cd90ce 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, func +from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers @@ -30,8 +30,8 @@ class ToolOAuthSystemClient(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -47,8 +47,8 @@ class ToolOAuthTenantClient(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -72,26 +72,26 @@ class BuiltinToolProvider(Base): # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) is_default: Mapped[bool] =
mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) @@ -113,12 +113,12 @@ class ApiToolProvider(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon - icon = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema = mapped_column(db.Text, nullable=False) - schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -130,12 +130,12 @@ class ApiToolProvider(Base): # json format credentials credentials_str = mapped_column(db.Text, nullable=False) # privacy policy - privacy_policy = mapped_column(db.String(255), nullable=True) + privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -173,11 +173,11 @@ class ToolLabelBinding(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id - tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) + tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + tool_type: Mapped[str] = mapped_column(String(40), nullable=False) # label name - label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + label_name: Mapped[str] = mapped_column(String(40), nullable=False) class WorkflowToolProvider(Base): @@ -194,15 +194,15 @@ class WorkflowToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # icon - icon: Mapped[str] = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the 
workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -212,13 +212,13 @@ class WorkflowToolProvider(Base): # parameter configuration parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) @property @@ -253,15 +253,15 @@ class MCPToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the mcp provider - name: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider server_url: Mapped[str] = mapped_column(db.Text, nullable=False) # hash of server_url for uniqueness check - server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider - icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + icon: Mapped[str] = mapped_column(String(255), nullable=True) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who created this tool @@ -273,10 +273,10 @@ class MCPToolProvider(Base): # tools tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) def load_user(self) -> Account | None: @@ -355,11 +355,11 @@ class ToolModelInvoke(Base): # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type = mapped_column(db.String(40), nullable=False) + tool_type = mapped_column(String(40), nullable=False) # tool name - tool_name = mapped_column(db.String(128), nullable=False) + tool_name = mapped_column(String(128), nullable=False) # invoke parameters model_parameters = mapped_column(db.Text, nullable=False) # prompt messages @@ -367,15 +367,15 @@ class ToolModelInvoke(Base): # invoke response model_response = mapped_column(db.Text, nullable=False) - prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - 
answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -402,8 +402,8 @@ class ToolConversationVariables(Base): # variables pool variables_str = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -429,11 +429,11 @@ class ToolFile(Base): # conversation id conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key - file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) + file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type - mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) + mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + original_url: Mapped[str] = mapped_column(String(2048), nullable=True) # name name: Mapped[str] = mapped_column(default="") # size @@ -465,13 +465,13 @@ class DeprecatedPublishedAppTool(Base): # to describe this parameter to llm, we need this field query_description = mapped_column(db.Text, nullable=False) # query name, the name of the query parameter - query_name = mapped_column(db.String(40), nullable=False) + query_name = mapped_column(String(40), nullable=False) # name of the tool provider - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(String(40), nullable=False) # author - author = mapped_column(db.String(40), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(String(40), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index ce00f4010f..1bf9b5c761 100644 --- 
diff --git a/api/models/web.py b/api/models/web.py
index ce00f4010f..1bf9b5c761 100644
--- a/api/models/web.py
+++ b/api/models/web.py
@@ -1,4 +1,6 @@
-from sqlalchemy import func
+from datetime import datetime
+
+from sqlalchemy import DateTime, String, func
 from sqlalchemy.orm import Mapped, mapped_column

 from models.base import Base
@@ -19,10 +21,10 @@ class SavedMessage(Base):
     app_id = mapped_column(StringUUID, nullable=False)
     message_id = mapped_column(StringUUID, nullable=False)
     created_by_role = mapped_column(
-        db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
+        String(255), nullable=False, server_default=db.text("'end_user'::character varying")
     )
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

     @property
     def message(self):
@@ -40,7 +42,7 @@ class PinnedConversation(Base):
     app_id = mapped_column(StringUUID, nullable=False)
     conversation_id: Mapped[str] = mapped_column(StringUUID)
     created_by_role = mapped_column(
-        db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
+        String(255), nullable=False, server_default=db.text("'end_user'::character varying")
     )
     created_by = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/workflow.py b/api/models/workflow.py
index d89db6c7da..6c7d061bb4 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import uuid4

 from flask_login import current_user
-from sqlalchemy import orm
+from sqlalchemy import DateTime, orm

 from core.file.constants import maybe_file_object
 from core.file.models import File
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
     from models.model import AppMode

 import sqlalchemy as sa
-from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func
+from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
 from sqlalchemy.orm import Mapped, declared_attr, mapped_column

 from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
@@ -124,17 +124,17 @@ class Workflow(Base):
     id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    type: Mapped[str] = mapped_column(db.String(255), nullable=False)
-    version: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    type: Mapped[str] = mapped_column(String(255), nullable=False)
+    version: Mapped[str] = mapped_column(String(255), nullable=False)
     marked_name: Mapped[str] = mapped_column(default="", server_default="")
     marked_comment: Mapped[str] = mapped_column(default="", server_default="")
     graph: Mapped[str] = mapped_column(sa.Text)
     _features: Mapped[str] = mapped_column("features", sa.TEXT)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
     updated_at: Mapped[datetime] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=False,
         default=naive_utc_now(),
         server_onupdate=func.current_timestamp(),
@@ -500,21 +500,21 @@ class WorkflowRun(Base):
     app_id: Mapped[str] = mapped_column(StringUUID)
     workflow_id: Mapped[str] = mapped_column(StringUUID)
-    type: Mapped[str] = mapped_column(db.String(255))
-    triggered_from: Mapped[str] = mapped_column(db.String(255))
-    version: Mapped[str] = mapped_column(db.String(255))
+    type: Mapped[str] = mapped_column(String(255))
+    triggered_from: Mapped[str] = mapped_column(String(255))
+    version: Mapped[str] = mapped_column(String(255))
     graph: Mapped[Optional[str]] = mapped_column(db.Text)
     inputs: Mapped[Optional[str]] = mapped_column(db.Text)
-    status: Mapped[str] = mapped_column(db.String(255))  # running, succeeded, failed, stopped, partial-succeeded
+    status: Mapped[str] = mapped_column(String(255))  # running, succeeded, failed, stopped, partial-succeeded
     outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
     error: Mapped[Optional[str]] = mapped_column(db.Text)
     elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
     total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
     total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
-    created_by_role: Mapped[str] = mapped_column(db.String(255))  # account, end_user
+    created_by_role: Mapped[str] = mapped_column(String(255))  # account, end_user
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+    finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
     exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)

     @property
@@ -708,25 +708,25 @@ class WorkflowNodeExecutionModel(Base):
     tenant_id: Mapped[str] = mapped_column(StringUUID)
     app_id: Mapped[str] = mapped_column(StringUUID)
     workflow_id: Mapped[str] = mapped_column(StringUUID)
-    triggered_from: Mapped[str] = mapped_column(db.String(255))
+    triggered_from: Mapped[str] = mapped_column(String(255))
     workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
     index: Mapped[int] = mapped_column(db.Integer)
-    predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
-    node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
-    node_id: Mapped[str] = mapped_column(db.String(255))
-    node_type: Mapped[str] = mapped_column(db.String(255))
-    title: Mapped[str] = mapped_column(db.String(255))
+    predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
+    node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
+    node_id: Mapped[str] = mapped_column(String(255))
+    node_type: Mapped[str] = mapped_column(String(255))
+    title: Mapped[str] = mapped_column(String(255))
     inputs: Mapped[Optional[str]] = mapped_column(db.Text)
     process_data: Mapped[Optional[str]] = mapped_column(db.Text)
     outputs: Mapped[Optional[str]] = mapped_column(db.Text)
-    status: Mapped[str] = mapped_column(db.String(255))
+    status: Mapped[str] = mapped_column(String(255))
     error: Mapped[Optional[str]] = mapped_column(db.Text)
     elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
     execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
-    created_by_role: Mapped[str] = mapped_column(db.String(255))
+    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
+    created_by_role: Mapped[str] = mapped_column(String(255))
     created_by: Mapped[str] = mapped_column(StringUUID)
-    finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+    finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)

     @property
     def created_by_account(self):
@@ -843,10 +843,10 @@ class WorkflowAppLog(Base):
     app_id: Mapped[str] = mapped_column(StringUUID)
     workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     workflow_run_id: Mapped[str] = mapped_column(StringUUID)
-    created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
-    created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    created_from: Mapped[str] = mapped_column(String(255), nullable=False)
+    created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

     @property
     def workflow_run(self):
@@ -873,10 +873,10 @@ class ConversationVariable(Base):
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
     data: Mapped[str] = mapped_column(db.Text, nullable=False)
     created_at: Mapped[datetime] = mapped_column(
-        db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
+        DateTime, nullable=False, server_default=func.current_timestamp(), index=True
     )
     updated_at: Mapped[datetime] = mapped_column(
-        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
     )

     def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
@@ -936,14 +936,14 @@ class WorkflowDraftVariable(Base):
     id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
     created_at: Mapped[datetime] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=False,
         default=_naive_utc_datetime,
         server_default=func.current_timestamp(),
     )
     updated_at: Mapped[datetime] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=False,
         default=_naive_utc_datetime,
         server_default=func.current_timestamp(),
@@ -958,7 +958,7 @@ class WorkflowDraftVariable(Base):
     #
     # If it's not edited after creation, its value is `None`.
     last_edited_at: Mapped[datetime | None] = mapped_column(
-        db.DateTime,
+        DateTime,
         nullable=True,
         default=None,
     )
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
index a05e1358ed..f0d3bed057 100644
--- a/api/schedule/queue_monitor_task.py
+++ b/api/schedule/queue_monitor_task.py
@@ -1,8 +1,8 @@
 import logging
 from datetime import datetime
-from urllib.parse import urlparse

 import click
+from kombu.utils.url import parse_url  # type: ignore
 from redis import Redis

 import app
@@ -10,16 +10,13 @@
 from configs import dify_config
 from extensions.ext_database import db
 from libs.email_i18n import EmailType, get_email_i18n_service

-# Create a dedicated Redis connection (using the same configuration as Celery)
-celery_broker_url = dify_config.CELERY_BROKER_URL
-
-parsed = urlparse(celery_broker_url)
-host = parsed.hostname or "localhost"
-port = parsed.port or 6379
-password = parsed.password or None
-redis_db = parsed.path.strip("/") or "1"  # type: ignore
-
-celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
+redis_config = parse_url(dify_config.CELERY_BROKER_URL)
+celery_redis = Redis(
+    host=redis_config.get("hostname") or "localhost",
+    port=redis_config.get("port") or 6379,
+    password=redis_config.get("password") or None,
+    db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+)

 @app.celery.task(queue="monitor")
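For reference, kombu's parse_url decodes percent-encoded credentials that the previous urlparse-based splitting mishandled, and it returns a mapping whose keys match the ones used above. A quick sketch of the behaviour this change relies on:

    from kombu.utils.url import parse_url

    # a percent-encoded '#' in the password breaks naive string splitting
    cfg = parse_url("redis://user:mypass%23123@redis-host:6380/2")

    assert cfg["hostname"] == "redis-host"
    assert cfg["port"] == 6380
    assert cfg["userid"] == "user"          # kombu calls it 'userid'
    assert cfg["password"] == "mypass#123"  # decoded back to the raw password
    assert cfg["virtual_host"] == "2"       # the Redis DB number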
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 3239af998e..b7a047914e 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -266,6 +266,54 @@ class AppAnnotationService:
             annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
         )

+    @classmethod
+    def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
+        # get app info
+        app = (
+            db.session.query(App)
+            .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+            .first()
+        )
+
+        if not app:
+            raise NotFound("App not found")
+
+        # Fetch annotations and their settings in a single query
+        annotations_to_delete = (
+            db.session.query(MessageAnnotation, AppAnnotationSetting)
+            .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
+            .filter(MessageAnnotation.id.in_(annotation_ids))
+            .all()
+        )
+
+        if not annotations_to_delete:
+            return {"deleted_count": 0}
+
+        # Step 1: Extract IDs for bulk operations
+        annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
+
+        # Step 2: Bulk delete hit histories in a single query
+        db.session.query(AppAnnotationHitHistory).filter(
+            AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
+        ).delete(synchronize_session=False)
+
+        # Step 3: Trigger async tasks for search index deletion
+        for annotation, annotation_setting in annotations_to_delete:
+            if annotation_setting:
+                delete_annotation_index_task.delay(
+                    annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
+                )
+
+        # Step 4: Bulk delete annotations in a single query
+        deleted_count = (
+            db.session.query(MessageAnnotation)
+            .filter(MessageAnnotation.id.in_(annotation_ids_to_delete))
+            .delete(synchronize_session=False)
+        )
+
+        db.session.commit()
+        return {"deleted_count": deleted_count}
+
     @classmethod
     def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
         # get app info
@@ -280,7 +328,7 @@ class AppAnnotationService:
         try:
             # Skip the first row
-            df = pd.read_csv(file)
+            df = pd.read_csv(file, dtype=str)
             result = []
             for index, row in df.iterrows():
                 content = {"question": row.iloc[0], "answer": row.iloc[1]}
@@ -452,6 +500,11 @@ class AppAnnotationService:
         if not app:
             raise NotFound("App not found")

+        # if annotation reply is enabled, delete annotation index
+        app_annotation_setting = (
+            db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        )
+
         annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id)
         for annotation in annotations_query.yield_per(100):
             annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter(
@@ -460,6 +513,12 @@
             for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
                 db.session.delete(annotation_hit_history)

+            # if annotation reply is enabled, delete annotation index
+            if app_annotation_setting:
+                delete_annotation_index_task.delay(
+                    annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
+                )
+
             db.session.delete(annotation)

         db.session.commit()
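One detail worth noting in the CSV import above: without dtype=str, pandas infers column types, so numeric-looking questions or answers come back as ints/floats and leading zeros are lost. Forcing strings keeps annotation text verbatim. A small illustration:

    import io

    import pandas as pd

    csv = io.StringIO("question,answer\n001,42\n")

    inferred = pd.read_csv(csv)
    assert inferred["question"].iloc[0] == 1        # int64 inferred, "001" becomes 1

    csv.seek(0)
    as_text = pd.read_csv(csv, dtype=str)
    assert as_text["question"].iloc[0] == "001"     # preserved exactly as written
    assert as_text["answer"].iloc[0] == "42"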
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 206c832a20..692a3639cd 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,12 +1,15 @@
 from collections.abc import Callable, Sequence
-from typing import Optional, Union
+from typing import Any, Optional, Union

 from sqlalchemy import asc, desc, func, or_, select
 from sqlalchemy.orm import Session

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
+from core.variables.types import SegmentType
+from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
 from extensions.ext_database import db
+from factories import variable_factory
 from libs.datetime_utils import naive_utc_now
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models import ConversationVariable
@@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message
 from services.errors.conversation import (
     ConversationNotExistsError,
     ConversationVariableNotExistsError,
+    ConversationVariableTypeMismatchError,
     LastConversationNotExistsError,
 )
 from services.errors.message import MessageNotExistsError
@@ -220,3 +224,82 @@ class ConversationService:
         ]

         return InfiniteScrollPagination(variables, limit, has_more)
+
+    @classmethod
+    def update_conversation_variable(
+        cls,
+        app_model: App,
+        conversation_id: str,
+        variable_id: str,
+        user: Optional[Union[Account, EndUser]],
+        new_value: Any,
+    ) -> dict:
+        """
+        Update a conversation variable's value.
+
+        Args:
+            app_model: The app model
+            conversation_id: The conversation ID
+            variable_id: The variable ID to update
+            user: The user (Account or EndUser)
+            new_value: The new value for the variable
+
+        Returns:
+            Dictionary containing the updated variable information
+
+        Raises:
+            ConversationNotExistsError: If the conversation doesn't exist
+            ConversationVariableNotExistsError: If the variable doesn't exist
+            ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type
+        """
+        # Verify conversation exists and user has access
+        conversation = cls.get_conversation(app_model, conversation_id, user)
+
+        # Get the existing conversation variable
+        stmt = (
+            select(ConversationVariable)
+            .where(ConversationVariable.app_id == app_model.id)
+            .where(ConversationVariable.conversation_id == conversation.id)
+            .where(ConversationVariable.id == variable_id)
+        )
+
+        with Session(db.engine) as session:
+            existing_variable = session.scalar(stmt)
+            if not existing_variable:
+                raise ConversationVariableNotExistsError()
+
+            # Convert existing variable to Variable object
+            current_variable = existing_variable.to_variable()
+
+            # Validate that the new value type matches the expected variable type
+            expected_type = SegmentType(current_variable.value_type)
+            if not expected_type.is_valid(new_value):
+                inferred_type = SegmentType.infer_segment_type(new_value)
+                raise ConversationVariableTypeMismatchError(
+                    f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, "
+                    f"but got {inferred_type.value if inferred_type else 'unknown'} type"
+                )
+
+            # Create updated variable with new value only, preserving everything else
+            updated_variable_dict = {
+                "id": current_variable.id,
+                "name": current_variable.name,
+                "description": current_variable.description,
+                "value_type": current_variable.value_type,
+                "value": new_value,
+                "selector": current_variable.selector,
+            }
+
+            updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
+
+            # Use the conversation variable updater to persist the changes
+            updater = conversation_variable_updater_factory()
+            updater.update(conversation_id, updated_variable)
+            updater.flush()
+
+            # Return the updated variable data
+            return {
+                "created_at": existing_variable.created_at,
+                "updated_at": naive_utc_now(),  # Update timestamp
+                **updated_variable.model_dump(),
+            }
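A hedged usage sketch of the new service method; the app/user objects are assumed to come from the usual controller context, and the wrapper name is illustrative:

    from services.conversation_service import ConversationService
    from services.errors.conversation import ConversationVariableTypeMismatchError


    def set_conversation_variable(app_model, conversation_id: str, variable_id: str, user, value):
        """Thin wrapper showing the call shape; arguments come from the controller layer."""
        try:
            return ConversationService.update_conversation_variable(
                app_model=app_model,
                conversation_id=conversation_id,
                variable_id=variable_id,
                user=user,
                new_value=value,  # checked against the variable's declared SegmentType
            )
        except ConversationVariableTypeMismatchError:
            # e.g. assigning a string to a variable declared as a number
            raise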
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 1280399990..da475a18f8 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -2040,6 +2040,7 @@ class SegmentService:
             db.session.add(segment_document)

             # update document word count
+            assert document.word_count is not None
             document.word_count += segment_document.word_count
             db.session.add(document)
             db.session.commit()
@@ -2124,6 +2125,7 @@
                 else:
                     keywords_list.append(None)
             # update document word count
+            assert document.word_count is not None
             document.word_count += increment_word_count
             db.session.add(document)
             try:
@@ -2185,6 +2187,7 @@
                 db.session.commit()
                 # update document word count
                 if word_count_change != 0:
+                    assert document.word_count is not None
                     document.word_count = max(0, document.word_count + word_count_change)
                     db.session.add(document)
                 # update segment index task
@@ -2260,6 +2263,7 @@
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
+                    assert document.word_count is not None
                     document.word_count = max(0, document.word_count + word_count_change)
                     db.session.add(document)
             db.session.add(segment)
@@ -2323,6 +2327,7 @@
             delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
         db.session.delete(segment)
         # update document word count
+        assert document.word_count is not None
         document.word_count -= segment.word_count
         db.session.add(document)
         db.session.commit()
diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py
index f8051e3417..a123f99b59 100644
--- a/api/services/errors/conversation.py
+++ b/api/services/errors/conversation.py
@@ -15,3 +15,7 @@ class ConversationCompletedError(Exception):

 class ConversationVariableNotExistsError(BaseServiceError):
     pass
+
+
+class ConversationVariableTypeMismatchError(BaseServiceError):
+    pass
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index b7af03e91f..2f1babba6f 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -46,9 +46,9 @@ class ExternalDatasetService:
     def validate_api_list(cls, api_settings: dict):
         if not api_settings:
             raise ValueError("api list is empty")
-        if "endpoint" not in api_settings and not api_settings["endpoint"]:
+        if not api_settings.get("endpoint"):
             raise ValueError("endpoint is required")
-        if "api_key" not in api_settings and not api_settings["api_key"]:
+        if not api_settings.get("api_key"):
             raise ValueError("api_key is required")

     @staticmethod
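The validate_api_list fix is worth spelling out: with `and`, a missing key made the first clause True and the subscript then raised KeyError, while a present-but-empty value made the first clause False, so empty endpoints slipped through. `dict.get()` collapses both cases into one falsy check. A quick demonstration:

    def old_check(api_settings: dict) -> None:
        if "endpoint" not in api_settings and not api_settings["endpoint"]:
            raise ValueError("endpoint is required")


    def new_check(api_settings: dict) -> None:
        if not api_settings.get("endpoint"):
            raise ValueError("endpoint is required")


    try:
        old_check({})                # KeyError, not the intended ValueError
    except KeyError:
        pass

    old_check({"endpoint": ""})      # silently accepts an empty endpoint

    try:
        new_check({"endpoint": ""})  # correctly rejected
    except ValueError:
        pass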
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index b8c56f9355..7311123985 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -509,10 +509,10 @@ class BuiltinToolManageService:
             oauth_params = encrypter.decrypt(user_client.oauth_params)
             return oauth_params

-        # only verified provider can use custom oauth client
-        is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
-            tenant_id, provider.plugin_unique_identifier
-        )
+        # only verified provider can use official oauth client
+        is_verified = not isinstance(
+            provider_controller, PluginToolProviderController
+        ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)

         if not is_verified:
             return oauth_params
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index abf6824d73..afcf1f7621 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -185,7 +185,7 @@ class WorkflowConverter:
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type=WorkflowType.from_app_mode(new_app_mode).value,
-            version="draft",
+            version=Workflow.VERSION_DRAFT,
             graph=json.dumps(graph),
             features=json.dumps(features),
             created_by=account_id,
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index e9f21fc5f1..8588144980 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -105,7 +105,9 @@ class WorkflowService:
         workflow = (
             db.session.query(Workflow)
             .where(
-                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
+                Workflow.tenant_id == app_model.tenant_id,
+                Workflow.app_id == app_model.id,
+                Workflow.version == Workflow.VERSION_DRAFT,
             )
             .first()
         )
@@ -219,7 +221,7 @@ class WorkflowService:
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type=WorkflowType.from_app_mode(app_model.mode).value,
-            version="draft",
+            version=Workflow.VERSION_DRAFT,
             graph=json.dumps(graph),
             features=json.dumps(features),
             created_by=account.id,
@@ -257,7 +259,7 @@
         draft_workflow_stmt = select(Workflow).where(
             Workflow.tenant_id == app_model.tenant_id,
             Workflow.app_id == app_model.id,
-            Workflow.version == "draft",
+            Workflow.version == Workflow.VERSION_DRAFT,
         )
         draft_workflow = session.scalar(draft_workflow_stmt)
         if not draft_workflow:
@@ -382,9 +384,9 @@
             tenant_id=app_model.tenant_id,
         )

-        eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
-        if eclosing_node_type_and_id:
-            _, enclosing_node_id = eclosing_node_type_and_id
+        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
+        if enclosing_node_type_and_id:
+            _, enclosing_node_id = enclosing_node_type_and_id
         else:
             enclosing_node_id = None

@@ -644,7 +646,7 @@
             raise ValueError(f"Workflow with ID {workflow_id} not found")

         # Check if workflow is a draft version
-        if workflow.version == "draft":
+        if workflow.version == Workflow.VERSION_DRAFT:
             raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

         # Check if this workflow is currently referenced by an app
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index a2105f8a9d..c5ee4ce3f9 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -32,6 +32,7 @@ def add_document_to_index_task(dataset_document_id: str):
         return

     if dataset_document.indexing_status != "completed":
+        db.session.close()
         return

     indexing_cache_key = f"document_{dataset_document.id}_indexing"
@@ -112,3 +113,4 @@
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
+        db.session.close()
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 714e30acc3..dee43cd854 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -134,6 +134,7 @@ def batch_create_segment_to_index_task(
             db.session.add(segment_document)
             document_segments.append(segment_document)
         # update document word count
+        assert dataset_document.word_count is not None
         dataset_document.word_count += word_count_change
         db.session.add(dataset_document)
         # add index to db
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index a8839ffc17..543a512851 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -31,6 +31,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
         return

     if segment.status != "waiting":
+        db.session.close()
         return

     indexing_cache_key = f"segment_{segment.id}_indexing"
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 56f330b964..993b2ac404 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -113,3 +113,5 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         logging.info(click.style(str(ex), fg="yellow"))
     except Exception:
         logging.exception("document_indexing_sync_task failed, document_id: %s", document_id)
+    finally:
+        db.session.close()
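The task changes share one theme: every exit path of a Celery task should release its SQLAlchemy session back to the pool, including early returns. A minimal sketch of the pattern, with the session and the work passed in so the snippet stands alone:

    def example_task(session, load, process) -> None:
        """Sketch of the cleanup pattern: `session` is a SQLAlchemy session,
        `load` fetches the work item, `process` does the actual work."""
        item = load(session)
        if item is None:
            session.close()  # early returns must release the connection too
            return
        try:
            process(item)
            session.commit()
        finally:
            session.close()  # runs on success and on exceptions alike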
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 2576d7b051..26b41aff2e 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -95,8 +95,8 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                 logging.info(click.style(str(ex), fg="yellow"))
                 redis_client.delete(retry_indexing_cache_key)
                 logging.exception("retry_document_indexing_task failed, document_id: %s", document_id)
-            end_at = time.perf_counter()
-            logging.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+        end_at = time.perf_counter()
+        logging.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
     except Exception as e:
         logging.exception(
             "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py
index 2a0c1bb038..a5ff5b9e82 100644
--- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py
+++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py
@@ -11,7 +11,9 @@ class ElasticSearchVectorTest(AbstractVectorTest):
         self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
         self.vector = ElasticSearchVector(
             index_name=self.collection_name.lower(),
-            config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"),
+            config=ElasticSearchConfig(
+                use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic"
+            ),
             attributes=self.attributes,
         )
diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
index da549af1b6..aebf3fbda1 100644
--- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
+++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
@@ -2,6 +2,7 @@ import os
 import uuid

 import tablestore
+from _pytest.python_api import approx

 from core.rag.datasource.vdb.tablestore.tablestore_vector import (
     TableStoreConfig,
@@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import (

 class TableStoreVectorTest(AbstractVectorTest):
-    def __init__(self):
+    def __init__(self, normalize_full_text_score: bool = False):
         super().__init__()
         self.vector = TableStoreVector(
             collection_name=self.collection_name,
@@ -25,6 +26,7 @@
                 instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
                 access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
                 access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
+                normalize_full_text_bm25_score=normalize_full_text_score,
             ),
         )
@@ -64,7 +66,21 @@
         docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
         assert len(docs) == 1
         assert docs[0].metadata["doc_id"] == self.example_doc_id
-        assert not hasattr(docs[0], "score")
+        if self.vector._config.normalize_full_text_bm25_score:
+            assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
+        else:
+            assert docs[0].metadata.get("score") is None
+
+        # return none if normalize_full_text_score=true and score_threshold > 0
+        docs = self.vector.search_by_full_text(
+            get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
+        )
+        if self.vector._config.normalize_full_text_bm25_score:
+            assert len(docs) == 0
+        else:
+            assert len(docs) == 1
+            assert docs[0].metadata["doc_id"] == self.example_doc_id
+            assert docs[0].metadata.get("score") is None

         docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
         assert len(docs) == 0
@@ -80,3 +96,5 @@

 def test_tablestore_vector(setup_mock_redis):
     TableStoreVectorTest().run_all_tests()
+    TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
+    TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
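Raw BM25 relevance scores are unbounded, which is why a score_threshold of 0.5 filters out the normalized hit (≈0.121) in the test above. The exact normalization the Tablestore integration applies is not shown in this diff; a common family of mappings into [0, 1), shown purely as an illustration, is s/(s+k):

    def normalize_bm25(score: float, k: float = 1.0) -> float:
        """Illustrative only: squashes an unbounded BM25 score into [0, 1).

        This is not necessarily the formula the Tablestore integration uses;
        it just shows why a normalized score can be compared to a threshold.
        """
        return score / (score + k)


    assert 0.0 <= normalize_bm25(0.14) < 0.5  # small raw score -> fails a 0.5 threshold
    assert normalize_bm25(10.0) > 0.9         # large raw score -> close to 1.0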
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index e9d4ee1935..0ae6a09f5b 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -1,5 +1,6 @@
 import os

+import pytest
 from flask import Flask
 from packaging.version import Version
 from yarl import URL
@@ -137,3 +138,61 @@ def test_db_extras_options_merging(monkeypatch):
     options = engine_options["connect_args"]["options"]
     assert "search_path=myschema" in options
     assert "timezone=UTC" in options
+
+
+@pytest.mark.parametrize(
+    ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
+    [
+        ("redis://localhost:6379/1", "localhost", 6379, None, None, "1"),
+        ("redis://:password@localhost:6379/1", "localhost", 6379, None, "password", "1"),
+        ("redis://:mypass%23123@localhost:6379/1", "localhost", 6379, None, "mypass#123", "1"),
+        ("redis://user:pass%40word@redis-host:6380/2", "redis-host", 6380, "user", "pass@word", "2"),
+        ("redis://admin:complex%23pass%40word@127.0.0.1:6379/0", "127.0.0.1", 6379, "admin", "complex#pass@word", "0"),
+        (
+            "redis://user%40domain:secret%23123@redis.example.com:6380/3",
+            "redis.example.com",
+            6380,
+            "user@domain",
+            "secret#123",
+            "3",
+        ),
+        # Password containing %23 substring (double encoding scenario)
+        ("redis://:mypass%2523@localhost:6379/1", "localhost", 6379, None, "mypass%23", "1"),
+        # Username and password both containing encoded characters
+        ("redis://user%2525%40:pass%2523@localhost:6379/1", "localhost", 6379, "user%25@", "pass%23", "1"),
+    ],
+)
+def test_celery_broker_url_with_special_chars_password(
+    monkeypatch, broker_url, expected_host, expected_port, expected_username, expected_password, expected_db
+):
+    """Test that CELERY_BROKER_URL with various formats are handled correctly."""
+    from kombu.utils.url import parse_url
+
+    # clear system environment variables
+    os.environ.clear()
+
+    # Set up basic required environment variables (following existing pattern)
+    monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
+    monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_USERNAME", "postgres")
+    monkeypatch.setenv("DB_PASSWORD", "postgres")
+    monkeypatch.setenv("DB_HOST", "localhost")
+    monkeypatch.setenv("DB_PORT", "5432")
+    monkeypatch.setenv("DB_DATABASE", "dify")
+
+    # Set the CELERY_BROKER_URL to test
+    monkeypatch.setenv("CELERY_BROKER_URL", broker_url)
+
+    # Create config and verify the URL is stored correctly
+    config = DifyConfig()
+    assert broker_url == config.CELERY_BROKER_URL
+
+    # Test actual parsing behavior using kombu's parse_url (same as production)
+    redis_config = parse_url(config.CELERY_BROKER_URL)
+
+    # Verify the parsing results match expectations (using kombu's field names)
+    assert redis_config["hostname"] == expected_host
+    assert redis_config["port"] == expected_port
+    assert redis_config["userid"] == expected_username  # kombu uses 'userid' not 'username'
+    assert redis_config["password"] == expected_password
+    assert redis_config["virtual_host"] == expected_db  # kombu uses 'virtual_host' not 'db'
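When credentials contain URL-reserved characters like '@' or '#', they must be percent-encoded before being embedded in a broker URL; the test vectors above were built exactly that way. A small sketch with the standard library:

    from urllib.parse import quote

    password = "complex#pass@word"
    encoded = quote(password, safe="")
    assert encoded == "complex%23pass%40word"

    broker_url = f"redis://admin:{encoded}@127.0.0.1:6379/0"
    # kombu's parse_url will decode it back to the original password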
in filename, f"Filename {filename} should be detected as path traversal" + + def test_should_detect_null_byte_injection(self): + """Test protection against null byte injection""" + dangerous_filenames = [ + "file.jpg\x00.php", + "document.pdf\x00.exe", + "image.png\x00.sh", + ] + + for filename in dangerous_filenames: + # Null bytes should be detected + assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection" + + def test_should_sanitize_special_characters(self): + """Test that special characters in filenames are handled safely""" + # Characters that could be problematic in various contexts + dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"] + + for char in dangerous_chars: + filename = f"file{char}name.txt" + # These characters should be detected or sanitized + assert any(c in filename for c in dangerous_chars) + + # Test 3: Permission validation + def test_should_validate_dataset_permissions(self): + """Test dataset upload permission logic""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = "datasets" + + # Simulate the permission check in FileApi.post() + if source == "datasets" and not user.is_dataset_editor: + with pytest.raises(Forbidden): + raise Forbidden() + + def test_should_allow_general_upload_without_permission(self): + """Test general upload doesn't require dataset permission""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = None # General upload + + # This should not raise an exception + if source == "datasets" and not user.is_dataset_editor: + raise Forbidden() + # Test passes if no exception is raised + + # Test 4: Service error handling + @patch("services.file_service.FileService.upload_file") + def test_should_handle_file_too_large_error(self, mock_upload): + """Test that service FileTooLargeError is properly converted""" + mock_upload.side_effect = ServiceFileTooLargeError("File too large") + + try: + mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None) + except ServiceFileTooLargeError as e: + # Simulate the error conversion in FileApi.post() + with pytest.raises(FileTooLargeError): + raise FileTooLargeError(e.description) + + @patch("services.file_service.FileService.upload_file") + def test_should_handle_unsupported_file_type_error(self, mock_upload): + """Test that service UnsupportedFileTypeError is properly converted""" + mock_upload.side_effect = ServiceUnsupportedFileTypeError() + + try: + mock_upload( + filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None + ) + except ServiceUnsupportedFileTypeError: + # Simulate the error conversion in FileApi.post() + with pytest.raises(UnsupportedFileTypeError): + raise UnsupportedFileTypeError() + + # Test 5: File type security + def test_should_identify_dangerous_file_extensions(self): + """Test detection of potentially dangerous file extensions""" + dangerous_extensions = [ + ".php", + ".PHP", + ".pHp", # PHP files (case variations) + ".exe", + ".EXE", # Executables + ".sh", + ".SH", # Shell scripts + ".bat", + ".BAT", # Batch files + ".cmd", + ".CMD", # Command files + ".ps1", + ".PS1", # PowerShell + ".jar", + ".JAR", # Java archives + ".vbs", + ".VBS", # VBScript + ] + + safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"] + + # Just verify our test data is correct + for ext in dangerous_extensions: + assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + for 
ext in safe_extensions: + assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + def test_should_detect_double_extensions(self): + """Test detection of double extension attacks""" + suspicious_filenames = [ + "image.jpg.php", + "document.pdf.exe", + "photo.png.sh", + "file.txt.bat", + ] + + for filename in suspicious_filenames: + # Check that these have multiple extensions + parts = filename.split(".") + assert len(parts) > 2, f"Filename {filename} should have multiple extensions" + + # Test 6: Configuration validation + def test_upload_configuration_structure(self): + """Test that upload configuration has correct structure""" + # Simulate the configuration returned by FileApi.get() + config = { + "file_size_limit": 15, + "batch_count_limit": 5, + "image_file_size_limit": 10, + "video_file_size_limit": 500, + "audio_file_size_limit": 50, + "workflow_file_upload_limit": 10, + } + + # Verify all required fields are present + required_fields = [ + "file_size_limit", + "batch_count_limit", + "image_file_size_limit", + "video_file_size_limit", + "audio_file_size_limit", + "workflow_file_upload_limit", + ] + + for field in required_fields: + assert field in config, f"Missing required field: {field}" + assert isinstance(config[field], int), f"Field {field} should be an integer" + assert config[field] > 0, f"Field {field} should be positive" + + # Test 7: Source parameter handling + def test_source_parameter_normalization(self): + """Test that source parameter is properly normalized""" + test_cases = [ + ("datasets", "datasets"), + ("other", None), + ("", None), + (None, None), + ] + + for input_source, expected in test_cases: + # Simulate the source normalization in FileApi.post() + source = "datasets" if input_source == "datasets" else None + if source not in ("datasets", None): + source = None + assert source == expected + + # Test 8: Boundary conditions + def test_should_handle_edge_case_file_sizes(self): + """Test handling of boundary file sizes""" + test_cases = [ + (0, "Empty file"), # 0 bytes + (1, "Single byte"), # 1 byte + (15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB + (15 * 1024 * 1024, "At limit"), # Exactly 15MB + (15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB + ] + + for size, description in test_cases: + # Just verify our test data + assert isinstance(size, int), f"{description}: Size should be integer" + assert size >= 0, f"{description}: Size should be non-negative" + + def test_should_handle_special_mime_types(self): + """Test handling of various MIME types""" + mime_type_tests = [ + ("application/octet-stream", "Generic binary"), + ("text/plain", "Plain text"), + ("image/jpeg", "JPEG image"), + ("application/pdf", "PDF document"), + ("", "Empty MIME type"), + (None, "None MIME type"), + ] + + for mime_type, description in mime_type_tests: + # Verify test data structure + if mime_type is not None: + assert isinstance(mime_type, str), f"{description}: MIME type should be string or None" diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 4bcc6cb605..1dc380ad0b 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -102,9 +102,14 @@ class TestPhoenixConfig: assert config.project == "default" def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1") - 
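The security tests above only assert that the dangerous patterns are present in the fixtures; an actual sanitizer would reject them. A hedged sketch of such a check (illustrative, not the validation Dify ships):

    import os


    def is_safe_filename(filename: str) -> bool:
        """Illustrative filename check mirroring the attack classes tested above."""
        if not filename:
            return False
        if "\x00" in filename:                      # null-byte injection
            return False
        if ".." in filename:                        # path traversal
            return False
        if os.path.basename(filename) != filename:  # embedded path separators
            return False
        return True


    assert not is_safe_filename("../../../etc/passwd")
    assert not is_safe_filename("file.jpg\x00.php")
    assert is_safe_filename("report.pdf")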
diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py
index 4bcc6cb605..1dc380ad0b 100644
--- a/api/tests/unit_tests/core/ops/test_config_entity.py
+++ b/api/tests/unit_tests/core/ops/test_config_entity.py
@@ -102,9 +102,14 @@ class TestPhoenixConfig:
         assert config.project == "default"

     def test_endpoint_validation_with_path(self):
-        """Test endpoint validation normalizes URL by removing path"""
-        config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
-        assert config.endpoint == "https://custom.phoenix.com"
+        """Test endpoint validation with path"""
+        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
+        assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
+
+    def test_endpoint_validation_without_path(self):
+        """Test endpoint validation without path"""
+        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
+        assert config.endpoint == "https://app.phoenix.arize.com"


 class TestLangfuseConfig:
@@ -118,7 +123,7 @@ class TestLangfuseConfig:
         assert config.host == "https://custom.langfuse.com"

     def test_valid_config_with_path(self):
-        host = host = "https://custom.langfuse.com/api/v1"
+        host = "https://custom.langfuse.com/api/v1"
         config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
         assert config.public_key == "public_key"
         assert config.secret_key == "secret_key"
@@ -368,13 +373,15 @@ class TestConfigIntegration:
         """Test that URL normalization works consistently across configs"""
         # Test that paths are removed from endpoints
         arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
-        phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
+        phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
+        phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
         aliyun_config = AliyunConfig(
             license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
         )

         assert arize_config.endpoint == "https://arize.com"
-        assert phoenix_config.endpoint == "https://phoenix.com"
+        assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
+        assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com"
         assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

     def test_project_default_values(self):
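The behavioural split these tests pin down: Arize and Aliyun endpoints are reduced to their origin (scheme + host), while Phoenix endpoints now keep their path, since Phoenix spaces live under one. Stripping a URL to its origin can be sketched with the standard library (illustrative; the configs' actual validators aren't shown here):

    from urllib.parse import urlsplit, urlunsplit


    def to_origin(url: str) -> str:
        """Drop path/query/fragment, keeping scheme and host only."""
        parts = urlsplit(url)
        return urlunsplit((parts.scheme, parts.netloc, "", "", ""))


    assert to_origin("https://arize.com/api/v1/test") == "https://arize.com"
    assert to_origin("https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces") == (
        "https://tracing-analysis-dc-hz.aliyuncs.com"
    )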
patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # Should crash with TypeError + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + # Test update method as well + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", None) + + def test_3_database_constraints_verification(self): + """Test Layer 3: Verify database model has nullable=False constraints.""" + from sqlalchemy import inspect + + from models.dataset import DatasetMetadata + + # Get table info + mapper = inspect(DatasetMetadata) + + # Check that type and name columns are not nullable + type_column = mapper.columns["type"] + name_column = mapper.columns["name"] + + assert type_column.nullable is False, "type column should be nullable=False" + assert name_column.nullable is False, "name column should be nullable=False" + + def test_4_fixed_api_layer_rejects_null(self, app): + """Test Layer 4: Fixed API configuration properly rejects null values.""" + # Test Console API create endpoint (fixed) + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + # Test with just name being null + with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + # Test with just type being null + with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + def test_5_fixed_api_accepts_valid_values(self, app): + """Test that fixed API still accepts valid non-null values.""" + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"): + args = parser.parse_args() + assert args["type"] == "string" + assert args["name"] == "valid_name" + + def test_6_simulated_buggy_behavior(self, app): + """Test simulating the original buggy behavior with nullable=True.""" + # Simulate the old buggy configuration + buggy_parser = reqparse.RequestParser() + buggy_parser.add_argument("type", type=str, required=True, nullable=True, location="json") + buggy_parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This would pass in the buggy version + args = buggy_parser.parse_args() + assert args["type"] is None + assert args["name"] is None + + # But would crash when trying to create MetadataArgs + with pytest.raises((ValueError, TypeError)): + MetadataArgs(**args) + + def test_7_end_to_end_validation_layers(self): + """Test all validation layers work 
together correctly.""" + # Layer 1: API should reject null at parameter level (with fix) + # Layer 2: Pydantic should reject null at model level + # Layer 3: Business logic expects non-null + # Layer 4: Database enforces non-null + + # Test that valid data flows through all layers + valid_data = {"type": "string", "name": "test_metadata"} + + # Should create valid Pydantic object + metadata_args = MetadataArgs(**valid_data) + assert metadata_args.type == "string" + assert metadata_args.name == "test_metadata" + + # Should not crash in business logic length check + assert len(metadata_args.name) <= 255 # This should not crash + assert len(metadata_args.type) > 0 # This should not crash + + def test_8_verify_specific_fix_locations(self): + """Verify that the specific locations mentioned in bug report are fixed.""" + # Read the actual files to verify fixes + import os + + # Console API create + console_create_file = "api/controllers/console/datasets/metadata.py" + if os.path.exists(console_create_file): + with open(console_create_file) as f: + content = f.read() + # Should contain nullable=False, not nullable=True + assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] + + # Service API create + service_create_file = "api/controllers/service_api/dataset/metadata.py" + if os.path.exists(service_create_file): + with open(service_create_file) as f: + content = f.read() + # Should contain nullable=False, not nullable=True + create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] + assert "nullable=True" not in create_api_section + + +class TestMetadataValidationSummary: + """Summary tests that demonstrate the complete validation architecture.""" + + def test_validation_layer_architecture(self): + """Document and test the 4-layer validation architecture.""" + # Layer 1: API Parameter Validation (Flask-RESTful reqparse) + # - Role: First line of defense, validates HTTP request parameters + # - Fixed: nullable=False ensures null values are rejected at API boundary + + # Layer 2: Pydantic Model Validation + # - Role: Validates data structure and types before business logic + # - Working: Required fields without Optional[] reject None values + + # Layer 3: Business Logic Validation + # - Role: Domain-specific validation (length checks, uniqueness, etc.) 
+ # - Vulnerable: Direct len() calls crash on None values + + # Layer 4: Database Constraints + # - Role: Final data integrity enforcement + # - Working: nullable=False prevents None values in database + + # The bug was: Layer 1 allowed None, but Layers 2-4 expected non-None + # The fix: Make Layer 1 consistent with Layers 2-4 + + assert True # This test documents the architecture + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py new file mode 100644 index 0000000000..ef4d05c1d9 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -0,0 +1,108 @@ +from unittest.mock import Mock, patch + +import pytest +from flask_restful import reqparse + +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataNullableBug: + """Test case to reproduce the metadata nullable validation bug.""" + + def test_metadata_args_with_none_values_should_fail(self): + """Test that MetadataArgs validation should reject None values.""" + # This test demonstrates the expected behavior - should fail validation + with pytest.raises((ValueError, TypeError)): + # This should fail because Pydantic expects non-None values + MetadataArgs(type=None, name=None) + + def test_metadata_service_create_with_none_name_crashes(self): + """Test that MetadataService.create_metadata crashes when name is None.""" + # Mock the MetadataArgs to bypass Pydantic validation + mock_metadata_args = Mock() + mock_metadata_args.name = None # This will cause len() to crash + mock_metadata_args.type = "string" + + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + def test_metadata_service_update_with_none_name_crashes(self): + """Test that MetadataService.update_metadata_name crashes when name is None.""" + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", None) + + def test_api_parser_accepts_null_values(self, app): + """Test that API parser configuration incorrectly accepts null values.""" + # Simulate the current API parser configuration + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + # Simulate request data with null values + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This should parse successfully due to nullable=True + args = parser.parse_args() + + # Verify that null values are accepted + assert args["type"] is None + assert args["name"] is None + + # This demonstrates the bug: API accepts None but business logic will crash + + def test_integration_bug_scenario(self, app): + """Test the complete bug scenario from API to service layer.""" + # Step 1: API 
parser accepts null values (current buggy behavior) + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + args = parser.parse_args() + + # Step 2: Try to create MetadataArgs with None values + # This should fail at Pydantic validation level + with pytest.raises((ValueError, TypeError)): + metadata_args = MetadataArgs(**args) + + # Step 3: If we bypass Pydantic (simulating the bug scenario) + # Move this outside the request context to avoid Flask-Login issues + mock_metadata_args = Mock() + mock_metadata_args.name = None # From args["name"] + mock_metadata_args.type = None # From args["type"] + + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # Step 4: Service layer crashes on len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + def test_correct_nullable_false_configuration_works(self, app): + """Test that the correct nullable=False configuration works as expected.""" + # This tests the FIXED configuration + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This should fail with BadRequest due to nullable=False + from werkzeug.exceptions import BadRequest + + with pytest.raises(BadRequest): + parser.parse_args() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/docker/.env.example b/docker/.env.example index 9d15ba53d3..13cac189aa 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -583,6 +583,17 @@ ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_PASSWORD=elastic KIBANA_PORT=5601 +# Using ElasticSearch Cloud Serverless, or not. 
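The root cause both test files circle around is the reqparse contract: required=True only rejects a missing key, while nullable controls whether an explicit JSON null passes. A minimal sketch of the fixed configuration (standalone Flask app purely for illustration):

    from flask import Flask
    from flask_restful import reqparse
    from werkzeug.exceptions import BadRequest

    app = Flask(__name__)

    parser = reqparse.RequestParser()
    # nullable=False is the fix: an explicit {"name": null} now yields a 400
    parser.add_argument("name", type=str, required=True, nullable=False, location="json")

    with app.test_request_context(json={"name": None}, content_type="application/json"):
        try:
            parser.parse_args()
        except BadRequest:
            print("null rejected at the API boundary, as intended")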
diff --git a/docker/.env.example b/docker/.env.example
index 9d15ba53d3..13cac189aa 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -583,6 +583,17 @@
 ELASTICSEARCH_USERNAME=elastic
 ELASTICSEARCH_PASSWORD=elastic
 KIBANA_PORT=5601

+# Using ElasticSearch Cloud Serverless, or not.
+ELASTICSEARCH_USE_CLOUD=false
+ELASTICSEARCH_CLOUD_URL=YOUR-ELASTICSEARCH_CLOUD_URL
+ELASTICSEARCH_API_KEY=YOUR-ELASTICSEARCH_API_KEY
+
+ELASTICSEARCH_VERIFY_CERTS=False
+ELASTICSEARCH_CA_CERTS=
+ELASTICSEARCH_REQUEST_TIMEOUT=100000
+ELASTICSEARCH_RETRY_ON_TIMEOUT=True
+ELASTICSEARCH_MAX_RETRIES=10
+
 # baidu vector configurations, only available when VECTOR_STORE is `baidu`
 BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
 BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
@@ -642,6 +653,7 @@
 TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
 TABLESTORE_INSTANCE_NAME=instance-name
 TABLESTORE_ACCESS_KEY_ID=xxx
 TABLESTORE_ACCESS_KEY_SECRET=xxx
+TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false

 # ------------------------------
 # Knowledge Configuration
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 9e0f78eb07..690dccb1a8 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -261,6 +261,14 @@ x-shared-env: &shared-api-worker-env
   ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
   ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
   KIBANA_PORT: ${KIBANA_PORT:-5601}
+  ELASTICSEARCH_USE_CLOUD: ${ELASTICSEARCH_USE_CLOUD:-false}
+  ELASTICSEARCH_CLOUD_URL: ${ELASTICSEARCH_CLOUD_URL:-YOUR-ELASTICSEARCH_CLOUD_URL}
+  ELASTICSEARCH_API_KEY: ${ELASTICSEARCH_API_KEY:-YOUR-ELASTICSEARCH_API_KEY}
+  ELASTICSEARCH_VERIFY_CERTS: ${ELASTICSEARCH_VERIFY_CERTS:-False}
+  ELASTICSEARCH_CA_CERTS: ${ELASTICSEARCH_CA_CERTS:-}
+  ELASTICSEARCH_REQUEST_TIMEOUT: ${ELASTICSEARCH_REQUEST_TIMEOUT:-100000}
+  ELASTICSEARCH_RETRY_ON_TIMEOUT: ${ELASTICSEARCH_RETRY_ON_TIMEOUT:-True}
+  ELASTICSEARCH_MAX_RETRIES: ${ELASTICSEARCH_MAX_RETRIES:-10}
   BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
   BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
   BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
@@ -304,6 +312,7 @@
   TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
   TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
   TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
+  TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
   UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
   ETL_TYPE: ${ETL_TYPE:-dify}
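How these variables plausibly map onto an Elasticsearch client: when ELASTICSEARCH_USE_CLOUD is true, the cloud URL plus API key replace host/port basic auth. A hedged sketch with the official elasticsearch-py client (the exact wiring inside Dify's ElasticSearchVector is not shown in this diff):

    import os

    from elasticsearch import Elasticsearch


    def client_from_env() -> Elasticsearch:
        if os.getenv("ELASTICSEARCH_USE_CLOUD", "false").lower() == "true":
            return Elasticsearch(
                os.environ["ELASTICSEARCH_CLOUD_URL"],
                api_key=os.environ["ELASTICSEARCH_API_KEY"],
                request_timeout=int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000")),
                retry_on_timeout=os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True") == "True",
                max_retries=int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10")),
            )
        # self-hosted path: basic auth against host/port
        return Elasticsearch(
            f"http://{os.getenv('ELASTICSEARCH_HOST', 'localhost')}:{os.getenv('ELASTICSEARCH_PORT', '9200')}",
            basic_auth=(
                os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
                os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
            ),
            verify_certs=os.getenv("ELASTICSEARCH_VERIFY_CERTS", "False") == "True",
        )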
translationFiles.forEach((file) => { + const filePath = path.join(folderPath, file) + const fileName = file.replace(/\.[^/.]+$/, '') + const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) => + c.toUpperCase(), + ) + + try { + const content = fs.readFileSync(filePath, 'utf8') + const moduleExports = {} + const context = { + exports: moduleExports, + module: { exports: moduleExports }, + require, + console, + __filename: filePath, + __dirname: folderPath, + } + + vm.runInNewContext(transpile(content), context) + const translationObj = (context.module.exports as any).default || context.module.exports + + if (!translationObj || typeof translationObj !== 'object') + throw new Error(`Error parsing file: ${filePath}`) + + const nestedKeys: string[] = [] + const iterateKeys = (obj: any, prefix = '') => { + for (const key in obj) { + const nestedKey = prefix ? `${prefix}.${key}` : key + if (typeof obj[key] === 'object' && obj[key] !== null && !Array.isArray(obj[key])) { + // This is an object (but not array), recurse into it but don't add it as a key + iterateKeys(obj[key], nestedKey) + } + else { + // This is a leaf node (string, number, boolean, array, etc.), add it as a key + nestedKeys.push(nestedKey) + } + } + } + iterateKeys(translationObj) + + const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) + allKeys.push(...fileKeys) + } + catch (error) { + reject(error) + } + }) + resolve(allKeys) + }) + }) + } + + beforeEach(() => { + // Clean up and create test directories + if (fs.existsSync(testDir)) + fs.rmSync(testDir, { recursive: true }) + + fs.mkdirSync(testDir, { recursive: true }) + fs.mkdirSync(testEnDir, { recursive: true }) + fs.mkdirSync(testZhDir, { recursive: true }) + }) + + afterEach(() => { + // Clean up test files + if (fs.existsSync(testDir)) + fs.rmSync(testDir, { recursive: true }) + }) + + describe('Key extraction logic', () => { + it('should extract only leaf node keys, not intermediate objects', async () => { + const testContent = `const translation = { + simple: 'Simple Value', + nested: { + level1: 'Level 1 Value', + deep: { + level2: 'Level 2 Value' + } + }, + array: ['not extracted'], + number: 42, + boolean: true +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), testContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toEqual([ + 'test.simple', + 'test.nested.level1', + 'test.nested.deep.level2', + 'test.array', + 'test.number', + 'test.boolean', + ]) + + // Should not include intermediate object keys + expect(keys).not.toContain('test.nested') + expect(keys).not.toContain('test.nested.deep') + }) + + it('should handle camelCase file name conversion correctly', async () => { + const testContent = `const translation = { + key: 'value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), testContent) + fs.writeFileSync(path.join(testEnDir, 'user_profile.ts'), testContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('appDebug.key') + expect(keys).toContain('userProfile.key') + }) + }) + + describe('Missing keys detection', () => { + it('should detect missing keys in target language', async () => { + const enContent = `const translation = { + common: { + save: 'Save', + cancel: 'Cancel', + delete: 'Delete' + }, + app: { + title: 'My App', + version: '1.0' + } +} + +export default translation +` + + const zhContent = `const translation = { + common: { + save: '保存', + cancel: '取消' + // missing 'delete' + }, + 
app: { + title: '我的应用' + // missing 'version' + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) + fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + + expect(missingKeys).toContain('test.common.delete') + expect(missingKeys).toContain('test.app.version') + expect(missingKeys).toHaveLength(2) + }) + }) + + describe('Extra keys detection', () => { + it('should detect extra keys in target language', async () => { + const enContent = `const translation = { + common: { + save: 'Save', + cancel: 'Cancel' + } +} + +export default translation +` + + const zhContent = `const translation = { + common: { + save: '保存', + cancel: '取消', + delete: '删除', // extra key + extra: '额外的' // another extra key + }, + newSection: { + someKey: '某个值' // extra section + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) + fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(extraKeys).toContain('test.common.delete') + expect(extraKeys).toContain('test.common.extra') + expect(extraKeys).toContain('test.newSection.someKey') + expect(extraKeys).toHaveLength(3) + }) + }) + + describe('File filtering logic', () => { + it('should filter keys by specific file correctly', async () => { + // Create multiple files + const file1Content = `const translation = { + button: 'Button', + text: 'Text' +} + +export default translation +` + + const file2Content = `const translation = { + title: 'Title', + description: 'Description' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'components.ts'), file1Content) + fs.writeFileSync(path.join(testEnDir, 'pages.ts'), file2Content) + fs.writeFileSync(path.join(testZhDir, 'components.ts'), file1Content) + fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content) + + const allEnKeys = await getKeysFromLanguage('en-US') + const allZhKeys = await getKeysFromLanguage('zh-Hans') + + // Test file filtering logic + const targetFile = 'components' + const filteredEnKeys = allEnKeys.filter(key => + key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), + ) + + expect(allEnKeys).toHaveLength(4) // 2 keys from each file + expect(filteredEnKeys).toHaveLength(2) // only components keys + expect(filteredEnKeys).toContain('components.button') + expect(filteredEnKeys).toContain('components.text') + expect(filteredEnKeys).not.toContain('pages.title') + expect(filteredEnKeys).not.toContain('pages.description') + }) + }) + + describe('Complex nested structure handling', () => { + it('should handle deeply nested objects correctly', async () => { + const complexContent = `const translation = { + level1: { + level2: { + level3: { + level4: { + deepValue: 'Deep Value' + }, + anotherValue: 'Another Value' + }, + simpleValue: 'Simple Value' + }, + directValue: 'Direct Value' + }, + rootValue: 'Root Value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'complex.ts'), complexContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('complex.level1.level2.level3.level4.deepValue') + 
expect(keys).toContain('complex.level1.level2.level3.anotherValue') + expect(keys).toContain('complex.level1.level2.simpleValue') + expect(keys).toContain('complex.level1.directValue') + expect(keys).toContain('complex.rootValue') + + // Should not include intermediate objects + expect(keys).not.toContain('complex.level1') + expect(keys).not.toContain('complex.level1.level2') + expect(keys).not.toContain('complex.level1.level2.level3') + expect(keys).not.toContain('complex.level1.level2.level3.level4') + }) + }) + + describe('Edge cases', () => { + it('should handle empty objects', async () => { + const emptyContent = `const translation = { + empty: {}, + withValue: 'value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'empty.ts'), emptyContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('empty.withValue') + expect(keys).not.toContain('empty.empty') + }) + + it('should handle special characters in keys', async () => { + const specialContent = `const translation = { + 'key-with-dash': 'value1', + 'key_with_underscore': 'value2', + 'key.with.dots': 'value3', + normalKey: 'value4' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'special.ts'), specialContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('special.key-with-dash') + expect(keys).toContain('special.key_with_underscore') + expect(keys).toContain('special.key.with.dots') + expect(keys).toContain('special.normalKey') + }) + + it('should handle different value types', async () => { + const typesContent = `const translation = { + stringValue: 'string', + numberValue: 42, + booleanValue: true, + nullValue: null, + undefinedValue: undefined, + arrayValue: ['array', 'values'], + objectValue: { + nested: 'nested value' + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'types.ts'), typesContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('types.stringValue') + expect(keys).toContain('types.numberValue') + expect(keys).toContain('types.booleanValue') + expect(keys).toContain('types.nullValue') + expect(keys).toContain('types.undefinedValue') + expect(keys).toContain('types.arrayValue') + expect(keys).toContain('types.objectValue.nested') + expect(keys).not.toContain('types.objectValue') + }) + }) + + describe('Real-world scenario tests', () => { + it('should handle app-debug structure like real files', async () => { + const appDebugEn = `const translation = { + pageTitle: { + line1: 'Prompt', + line2: 'Engineering' + }, + operation: { + applyConfig: 'Publish', + resetConfig: 'Reset', + debugConfig: 'Debug' + }, + generate: { + instruction: 'Instructions', + generate: 'Generate', + resTitle: 'Generated Prompt', + noDataLine1: 'Describe your use case on the left,', + noDataLine2: 'the orchestration preview will show here.' 
+ } +} + +export default translation +` + + const appDebugZh = `const translation = { + pageTitle: { + line1: '提示词', + line2: '编排' + }, + operation: { + applyConfig: '发布', + resetConfig: '重置', + debugConfig: '调试' + }, + generate: { + instruction: '指令', + generate: '生成', + resTitle: '生成的提示词', + noData: '在左侧描述您的用例,编排预览将在此处显示。' // This is extra + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), appDebugEn) + fs.writeFileSync(path.join(testZhDir, 'app-debug.ts'), appDebugZh) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(missingKeys).toContain('appDebug.generate.noDataLine1') + expect(missingKeys).toContain('appDebug.generate.noDataLine2') + expect(extraKeys).toContain('appDebug.generate.noData') + + expect(missingKeys).toHaveLength(2) + expect(extraKeys).toHaveLength(1) + }) + + it('should handle time structure with operation nested keys', async () => { + const timeEn = `const translation = { + months: { + January: 'January', + February: 'February' + }, + operation: { + now: 'Now', + ok: 'OK', + cancel: 'Cancel', + pickDate: 'Pick Date' + }, + title: { + pickTime: 'Pick Time' + }, + defaultPlaceholder: 'Pick a time...' +} + +export default translation +` + + const timeZh = `const translation = { + months: { + January: '一月', + February: '二月' + }, + operation: { + now: '此刻', + ok: '确定', + cancel: '取消', + pickDate: '选择日期' + }, + title: { + pickTime: '选择时间' + }, + pickDate: '选择日期', // This is extra - duplicates operation.pickDate + defaultPlaceholder: '请选择时间...' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'time.ts'), timeEn) + fs.writeFileSync(path.join(testZhDir, 'time.ts'), timeZh) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(missingKeys).toHaveLength(0) // No missing keys + expect(extraKeys).toContain('time.pickDate') // Extra root-level pickDate + expect(extraKeys).toHaveLength(1) + + // Should have both keys available + expect(zhKeys).toContain('time.operation.pickDate') // Correct nested key + expect(zhKeys).toContain('time.pickDate') // Extra duplicate key + }) + }) + + describe('Statistics calculation', () => { + it('should calculate correct difference statistics', async () => { + const enContent = `const translation = { + key1: 'value1', + key2: 'value2', + key3: 'value3' +} + +export default translation +` + + const zhContentMissing = `const translation = { + key1: 'value1', + key2: 'value2' + // missing key3 +} + +export default translation +` + + const zhContentExtra = `const translation = { + key1: 'value1', + key2: 'value2', + key3: 'value3', + key4: 'extra', + key5: 'extra2' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'stats.ts'), enContent) + + // Test missing keys scenario + fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentMissing) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeysMissing = await getKeysFromLanguage('zh-Hans') + + expect(enKeys.length - zhKeysMissing.length).toBe(1) // +1 means 1 missing key + + // Test extra keys scenario + fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentExtra) + + const zhKeysExtra = 
await getKeysFromLanguage('zh-Hans') + + expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys + }) + }) +}) diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx new file mode 100644 index 0000000000..370052bc80 --- /dev/null +++ b/web/__tests__/plugin-tool-workflow-error.test.tsx @@ -0,0 +1,207 @@ +/** + * Test cases to reproduce the plugin tool workflow error + * Issue: #23154 - Application error when loading plugin tools in workflow + * Root cause: split() operation called on null/undefined values + */ + +describe('Plugin Tool Workflow Error Reproduction', () => { + /** + * Mock function to simulate the problematic code in switch-plugin-version.tsx:29 + * const [pluginId] = uniqueIdentifier.split(':') + */ + const mockSwitchPluginVersionLogic = (uniqueIdentifier: string | null | undefined) => { + // This directly reproduces the problematic line from switch-plugin-version.tsx:29 + const [pluginId] = uniqueIdentifier!.split(':') + return pluginId + } + + /** + * Test case 1: Simulate null uniqueIdentifier + * This should reproduce the error mentioned in the issue + */ + it('should reproduce error when uniqueIdentifier is null', () => { + expect(() => { + mockSwitchPluginVersionLogic(null) + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test case 2: Simulate undefined uniqueIdentifier + */ + it('should reproduce error when uniqueIdentifier is undefined', () => { + expect(() => { + mockSwitchPluginVersionLogic(undefined) + }).toThrow('Cannot read properties of undefined (reading \'split\')') + }) + + /** + * Test case 3: Simulate empty string uniqueIdentifier + */ + it('should handle empty string uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('') + expect(result).toBe('') // Empty string split by ':' returns [''] + }).not.toThrow() + }) + + /** + * Test case 4: Simulate malformed uniqueIdentifier without colon separator + */ + it('should handle malformed uniqueIdentifier without colon separator', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('malformed-identifier-without-colon') + expect(result).toBe('malformed-identifier-without-colon') // No colon means full string returned + }).not.toThrow() + }) + + /** + * Test case 5: Simulate valid uniqueIdentifier + */ + it('should work correctly with valid uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('valid-plugin-id:1.0.0') + expect(result).toBe('valid-plugin-id') + }).not.toThrow() + }) +}) + +/** + * Test for the variable processing split error in use-single-run-form-params + */ +describe('Variable Processing Split Error', () => { + /** + * Mock function to simulate the problematic code in use-single-run-form-params.ts:91 + * const getDependentVars = () => { + * return varInputs.map(item => item.variable.slice(1, -1).split('.')) + * } + */ + const mockGetDependentVars = (varInputs: Array<{ variable: string | null | undefined }>) => { + return varInputs.map((item) => { + // Guard against null/undefined variable to prevent app crash + if (!item.variable || typeof item.variable !== 'string') + return [] + + return item.variable.slice(1, -1).split('.') + }).filter(arr => arr.length > 0) // Filter out empty arrays + } + + /** + * Test case 1: Variable processing with null variable + */ + it('should handle null variable safely', () => { + const varInputs = [{ variable: null }] + + expect(() => { + 
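+        // With the guard in place, a null variable yields [] instead of a TypeError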
mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // null variables are filtered out + }) + + /** + * Test case 2: Variable processing with undefined variable + */ + it('should handle undefined variable safely', () => { + const varInputs = [{ variable: undefined }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // undefined variables are filtered out + }) + + /** + * Test case 3: Variable processing with empty string + */ + it('should handle empty string variable', () => { + const varInputs = [{ variable: '' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // Empty string is filtered out, so result is empty array + }) + + /** + * Test case 4: Variable processing with valid variable format + */ + it('should work correctly with valid variable format', () => { + const varInputs = [{ variable: '{{workflow.node.output}}' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result[0]).toEqual(['{workflow', 'node', 'output}']) + }) +}) + +/** + * Integration test to simulate the complete workflow scenario + */ +describe('Plugin Tool Workflow Integration', () => { + /** + * Simulate the scenario where plugin metadata is incomplete or corrupted + * This can happen when: + * 1. Plugin is being loaded from marketplace but metadata request fails + * 2. Plugin configuration is corrupted in database + * 3. Network issues during plugin loading + */ + it('should reproduce the client-side exception scenario', () => { + // Mock incomplete plugin data that could cause the error + const incompletePluginData = { + // Missing or null uniqueIdentifier + uniqueIdentifier: null, + meta: null, + minimum_dify_version: undefined, + } + + // This simulates the error path that leads to the white screen + expect(() => { + // Simulate the code path in switch-plugin-version.tsx:29 + // The actual problematic code doesn't use optional chaining + const _pluginId = (incompletePluginData.uniqueIdentifier as any).split(':')[0] + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test the scenario mentioned in the issue where plugin tools are loaded in workflow + */ + it('should simulate plugin tool loading in workflow context', () => { + // Mock the workflow context where plugin tools are being loaded + const workflowPluginTools = [ + { + provider_name: 'test-plugin', + uniqueIdentifier: null, // This is the problematic case + tool_name: 'test-tool', + }, + { + provider_name: 'valid-plugin', + uniqueIdentifier: 'valid-plugin:1.0.0', + tool_name: 'valid-tool', + }, + ] + + // Process each plugin tool + workflowPluginTools.forEach((tool, _index) => { + if (tool.uniqueIdentifier === null) { + // This reproduces the exact error scenario + expect(() => { + const _pluginId = (tool.uniqueIdentifier as any).split(':')[0] + }).toThrow() + } + else { + // Valid tools should work fine + expect(() => { + const _pluginId = tool.uniqueIdentifier.split(':')[0] + }).not.toThrow() + } + }) + }) +}) diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx new file mode 100644 index 0000000000..0843122ab4 --- /dev/null +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -0,0 +1,301 @@ +/** + * 
MAX_PARALLEL_LIMIT Configuration Bug Test
+ *
+ * This test reproduces and verifies the fix for issue #23083:
+ * MAX_PARALLEL_LIMIT environment variable does not take effect in iteration panel
+ */
+
+import { render, screen } from '@testing-library/react'
+import React from 'react'
+
+// Mock environment variables before importing constants
+const originalEnv = process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
+
+// Test with different environment values
+function setupEnvironment(value?: string) {
+  if (value)
+    process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = value
+  else
+    delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
+
+  // Clear module cache to force re-evaluation
+  jest.resetModules()
+}
+
+function restoreEnvironment() {
+  if (originalEnv)
+    process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = originalEnv
+  else
+    delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
+
+  jest.resetModules()
+}
+
+// Mock i18next with proper implementation
+jest.mock('react-i18next', () => ({
+  useTranslation: () => ({
+    t: (key: string) => {
+      if (key.includes('MaxParallelismTitle')) return 'Max Parallelism'
+      if (key.includes('MaxParallelismDesc')) return 'Maximum number of parallel executions'
+      if (key.includes('parallelMode')) return 'Parallel Mode'
+      if (key.includes('parallelPanelDesc')) return 'Enable parallel execution'
+      if (key.includes('errorResponseMethod')) return 'Error Response Method'
+      return key
+    },
+  }),
+  initReactI18next: {
+    type: '3rdParty',
+    init: jest.fn(),
+  },
+}))
+
+// Mock i18next module completely to prevent initialization issues
+jest.mock('i18next', () => ({
+  use: jest.fn().mockReturnThis(),
+  init: jest.fn().mockReturnThis(),
+  t: jest.fn(key => key),
+  isInitialized: true,
+}))
+
+// Mock the useConfig hook
+jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({
+  __esModule: true,
+  default: () => ({
+    inputs: {
+      is_parallel: true,
+      parallel_nums: 5,
+      error_handle_mode: 'terminated',
+    },
+    changeParallel: jest.fn(),
+    changeParallelNums: jest.fn(),
+    changeErrorHandleMode: jest.fn(),
+  }),
+}))
+
+// Mock other components
+jest.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => {
+  return function MockVarReferencePicker() {
+    return <div>VarReferencePicker</div>
+  }
+})
+
+jest.mock('@/app/components/workflow/nodes/_base/components/split', () => {
+  return function MockSplit() {
+    return <div>Split</div>
+  }
+})
+
+jest.mock('@/app/components/workflow/nodes/_base/components/field', () => {
+  return function MockField({ title, children }: { title: string, children: React.ReactNode }) {
+    return (
+      <div title={title}>
+        {children}
+      </div>
+    )
+  }
+})
+
+jest.mock('@/app/components/base/switch', () => {
+  return function MockSwitch({ defaultValue }: { defaultValue: boolean }) {
+    return <input type='checkbox' defaultChecked={defaultValue} />
+  }
+})
+
+jest.mock('@/app/components/base/select', () => {
+  return function MockSelect() {
+    return <select />
+  }
+})
+
+// Use defaultValue to avoid controlled input warnings
+jest.mock('@/app/components/base/slider', () => {
+  return function MockSlider({ value, max, min }: { value: number, max: number, min: number }) {
+    return (
+      <input type='range' data-testid='slider' data-max={max} data-min={min} defaultValue={value} />
+    )
+  }
+})
+
+// Use defaultValue to avoid controlled input warnings
+jest.mock('@/app/components/base/input', () => {
+  return function MockInput({ type, max, min, value }: { type: string, max: number, min: number, value: number }) {
+    return (
+      <input type={type} data-testid='number-input' data-max={max} data-min={min} defaultValue={value} />
+    )
+  }
+})
+
+describe('MAX_PARALLEL_LIMIT Configuration Bug', () => {
+  const mockNodeData = {
+    id: 'test-iteration-node',
+    type: 'iteration' as const,
+    data: {
+      title: 'Test Iteration',
+      desc: 'Test iteration node',
+      iterator_selector: ['test'],
+      output_selector: ['output'],
+      is_parallel: true,
+      parallel_nums: 5,
+      error_handle_mode: 'terminated' as const,
+    },
+  }
+
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  afterEach(() => {
+    restoreEnvironment()
+  })
+
+  afterAll(() => {
+    restoreEnvironment()
+  })
+
+  describe('Environment Variable Parsing', () => {
+    it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', () => {
+      setupEnvironment('25')
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+      expect(MAX_PARALLEL_LIMIT).toBe(25)
+    })
+
+    it('should fall back to the default when the environment variable is not set', () => {
+      setupEnvironment() // No environment variable
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+      expect(MAX_PARALLEL_LIMIT).toBe(10)
+    })
+
+    it('should handle invalid environment variable values', () => {
+      setupEnvironment('invalid')
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+
+      // Should fall back to default when parsing fails
+      expect(MAX_PARALLEL_LIMIT).toBe(10)
+    })
+
+    it('should handle an empty environment variable', () => {
+      setupEnvironment('')
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+
+      // Should fall back to default when empty
+      expect(MAX_PARALLEL_LIMIT).toBe(10)
+    })
+
+    // Edge cases for boundary values
+    it('should fall back to the default when env is 0 or negative', () => {
+      setupEnvironment('0')
+      let { MAX_PARALLEL_LIMIT } = require('@/config')
+      expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default
+
+      setupEnvironment('-5')
+      ;({ MAX_PARALLEL_LIMIT } = require('@/config'))
+      expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default
+    })
+
+    it('should truncate float values per parseInt behavior', () => {
+      setupEnvironment('12.7')
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+      // parseInt truncates to integer
+      expect(MAX_PARALLEL_LIMIT).toBe(12)
+    })
+  })
+
+  describe('UI Component Integration (Main Fix Verification)', () => {
+    it('should render iteration panel with environment-configured max value', () => {
+      // Set environment variable to a different value
+      setupEnvironment('30')
+
+      // Import Panel after setting environment
+      const Panel = require('@/app/components/workflow/nodes/iteration/panel').default
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+
+      render(
+        <Panel id={mockNodeData.id} data={mockNodeData.data as any} />,
+      )
+
+      // Behavior-focused assertion: UI max should equal MAX_PARALLEL_LIMIT
+      const numberInput = screen.getByTestId('number-input')
+      expect(numberInput).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT))
+
+      const slider = screen.getByTestId('slider')
+      expect(slider).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT))
+
+      // Verify the actual values
+      expect(MAX_PARALLEL_LIMIT).toBe(30)
+      expect(numberInput.getAttribute('data-max')).toBe('30')
+      expect(slider.getAttribute('data-max')).toBe('30')
+    })
+
+    it('should maintain UI consistency with different environment values', () => {
+      setupEnvironment('15')
+      const Panel = require('@/app/components/workflow/nodes/iteration/panel').default
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+
+      render(
+        <Panel id={mockNodeData.id} data={mockNodeData.data as any} />,
+      )
+
+      // Both input and slider should use the same max value from MAX_PARALLEL_LIMIT
+      const numberInput = screen.getByTestId('number-input')
+      const slider = screen.getByTestId('slider')
+
+      expect(numberInput.getAttribute('data-max')).toBe(slider.getAttribute('data-max'))
+      expect(numberInput.getAttribute('data-max')).toBe(String(MAX_PARALLEL_LIMIT))
+    })
+  })
+
+  describe('Legacy Constant Verification (For Transition Period)', () => {
+    // Marked as transition/deprecation tests
+    it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', () => {
+      const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
+      expect(typeof MAX_ITERATION_PARALLEL_NUM).toBe('number')
+      expect(MAX_ITERATION_PARALLEL_NUM).toBe(10) // Hardcoded legacy value
+    })
+
+    it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', () => {
+      setupEnvironment('50')
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+      const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
+
+      // MAX_PARALLEL_LIMIT is configurable, MAX_ITERATION_PARALLEL_NUM is not
+      expect(MAX_PARALLEL_LIMIT).toBe(50)
+      expect(MAX_ITERATION_PARALLEL_NUM).toBe(10)
+      expect(MAX_PARALLEL_LIMIT).not.toBe(MAX_ITERATION_PARALLEL_NUM)
+    })
+  })
+
+  describe('Constants Validation', () => {
+    it('should validate that required constants exist and have correct types', () => {
+      const { MAX_PARALLEL_LIMIT } = require('@/config')
+      const { MIN_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants')
+      expect(typeof MAX_PARALLEL_LIMIT).toBe('number')
+      expect(typeof MIN_ITERATION_PARALLEL_NUM).toBe('number')
+      expect(MAX_PARALLEL_LIMIT).toBeGreaterThanOrEqual(MIN_ITERATION_PARALLEL_NUM)
+    })
+  })
+})
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx
index 58c9f7e5ca..c04d79d2f2 100644
--- a/web/app/components/app-sidebar/app-info.tsx
+++ b/web/app/components/app-sidebar/app-info.tsx
@@ -271,16 +271,17 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
-          {
-            expand && (
-              <div>
-                <div>
-                  <div>{appDetail.name}</div>
-                </div>
-                <div>
-                  {appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
-                </div>
-              </div>
-            )
-          }
+          <div>
+            <div>
+              <div>{appDetail.name}</div>
+            </div>
+            <div>
+              {appDetail.mode === 'advanced-chat' ? t('app.types.advanced') : appDetail.mode === 'agent-chat' ? t('app.types.agent') : appDetail.mode === 'chat' ? t('app.types.chatbot') : appDetail.mode === 'completion' ? t('app.types.completion') : t('app.types.workflow')}
+            </div>
+          </div>
         )}
diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx
index b6bfc0e9ac..cf32339b8a 100644
--- a/web/app/components/app-sidebar/index.tsx
+++ b/web/app/components/app-sidebar/index.tsx
@@ -124,10 +124,7 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati
         { !isMobile && (
+jest.mock('next/navigation', () => ({
+  useSelectedLayoutSegment: () => 'overview',
+}))
+
+// Mock Next.js Link component
+jest.mock('next/link', () => {
+  return function MockLink({ children, href, className, title }: any) {
+    return (
+      <a href={href} className={className} title={title} data-testid='nav-link'>
+        {children}
+      </a>
+    )
+  }
+})
+
+// Mock RemixIcon components
+const MockIcon = ({ className }: { className?: string }) => (
+  <span className={className} data-testid='nav-icon' />
+)
+
+describe('NavLink Text Animation Issues', () => {
+  const mockProps: NavLinkProps = {
+    name: 'Orchestrate',
+    href: '/app/123/workflow',
+    iconMap: {
+      selected: MockIcon,
+      normal: MockIcon,
+    },
+  }
+
+  beforeEach(() => {
+    // Mock getComputedStyle for transition testing
+    Object.defineProperty(window, 'getComputedStyle', {
+      value: jest.fn((element) => {
+        const isExpanded = element.getAttribute('data-mode') === 'expand'
+        return {
+          transition: 'all 0.3s ease',
+          opacity: isExpanded ? '1' : '0',
+          width: isExpanded ? 'auto' : '0px',
+          overflow: 'hidden',
+          paddingLeft: isExpanded ? '12px' : '10px', // px-3 vs px-2.5
+          paddingRight: isExpanded ? '12px' : '10px',
+        }
+      }),
+      writable: true,
+    })
+  })
+
+  describe('Text Squeeze Animation Issue', () => {
+    it('should show text squeeze effect when switching from collapse to expand', async () => {
+      const { rerender } = render(<NavLink {...mockProps} mode='collapse' />)
+
+      // In collapse mode, text should be in DOM but hidden via CSS
+      const textElement = screen.getByText('Orchestrate')
+      expect(textElement).toBeInTheDocument()
+      expect(textElement).toHaveClass('opacity-0')
+      expect(textElement).toHaveClass('w-0')
+      expect(textElement).toHaveClass('overflow-hidden')
+
+      // Icon should still be present
+      expect(screen.getByTestId('nav-icon')).toBeInTheDocument()
+
+      // Check padding in collapse mode
+      const linkElement = screen.getByTestId('nav-link')
+      expect(linkElement).toHaveClass('px-2.5')
+
+      // Switch to expand mode - this is where the squeeze effect occurs
+      rerender(<NavLink {...mockProps} mode='expand' />)
+
+      // Text should now appear
+      expect(screen.getByText('Orchestrate')).toBeInTheDocument()
+
+      // Check padding change - this contributes to the squeeze effect
+      expect(linkElement).toHaveClass('px-3')
+
+      // The bug: text appears abruptly without smooth transition
+      // This test documents the current behavior that causes the squeeze effect
+      const expandedTextElement = screen.getByText('Orchestrate')
+      expect(expandedTextElement).toBeInTheDocument()
+
+      // In a properly animated version, we would expect:
+      // - Opacity transition from 0 to 1
+      // - Width transition from 0 to auto
+      // - No layout shift from padding changes
+    })
+
+    it('should maintain icon position consistency during text appearance', () => {
+      const { rerender } = render(<NavLink {...mockProps} mode='collapse' />)
+
+      const iconElement = screen.getByTestId('nav-icon')
+      const initialIconClasses = iconElement.className
+
+      // Icon should have mr-0 in collapse mode
+      expect(iconElement).toHaveClass('mr-0')
+
+      rerender(<NavLink {...mockProps} mode='expand' />)
+
+      const expandedIconClasses = iconElement.className
+
+      // Icon should have mr-2 in expand mode - this shift contributes to the squeeze effect
+      expect(iconElement).toHaveClass('mr-2')
+
+      console.log('Collapsed icon classes:', initialIconClasses)
+      console.log('Expanded icon classes:', expandedIconClasses)
+
+      // This margin change causes the icon to shift when text appears
+    })
+
+    it('should document the abrupt text rendering issue', () => {
+      const { rerender } = render(<NavLink {...mockProps} mode='collapse' />)
+
+      // Text is present in DOM but hidden via CSS classes
+      const collapsedText = screen.getByText('Orchestrate')
+      expect(collapsedText).toBeInTheDocument()
+      expect(collapsedText).toHaveClass('opacity-0')
+      expect(collapsedText).toHaveClass('pointer-events-none')
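+      // pointer-events-none keeps the hidden label from intercepting clicks while collapsed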
+
+      rerender(<NavLink {...mockProps} mode='expand' />)
+
+      // Text suddenly appears in DOM - no transition
+      expect(screen.getByText('Orchestrate')).toBeInTheDocument()
+
+      // The issue: {mode === 'expand' && name} causes abrupt show/hide
+      // instead of smooth opacity/width transition
+    })
+  })
+
+  describe('Layout Shift Issues', () => {
+    it('should detect padding differences causing layout shifts', () => {
+      const { rerender } = render(<NavLink {...mockProps} mode='collapse' />)
+
+      const linkElement = screen.getByTestId('nav-link')
+
+      // Collapsed state padding
+      expect(linkElement).toHaveClass('px-2.5')
+
+      rerender(<NavLink {...mockProps} mode='expand' />)
+
+      // Expanded state padding - different value causes layout shift
+      expect(linkElement).toHaveClass('px-3')
+
+      // This 2px difference (10px vs 12px) contributes to the squeeze effect
+    })
+
+    it('should detect icon margin changes causing shifts', () => {
+      const { rerender } = render(<NavLink {...mockProps} mode='collapse' />)
+
+      const iconElement = screen.getByTestId('nav-icon')
+
+      // Collapsed: no right margin
+      expect(iconElement).toHaveClass('mr-0')
+
+      rerender(<NavLink {...mockProps} mode='expand' />)
+
+      // Expanded: 8px right margin (mr-2)
+      expect(iconElement).toHaveClass('mr-2')
+
+      // This sudden margin appearance causes the squeeze effect
+    })
+  })
+
+  describe('Active State Handling', () => {
+    it('should handle active state correctly in both modes', () => {
+      // Test non-active state
+      const { rerender } = render(<NavLink {...mockProps} mode='expand' />)
+
+      let linkElement = screen.getByTestId('nav-link')
+      expect(linkElement).not.toHaveClass('bg-state-accent-active')
+
+      // Test with active state (when href matches current segment)
+      const activeProps = {
+        ...mockProps,
+        href: '/app/123/overview', // matches mocked segment
+      }
+
+      rerender(<NavLink {...activeProps} mode='expand' />)
+
+      linkElement = screen.getByTestId('nav-link')
+      expect(linkElement).toHaveClass('bg-state-accent-active')
+    })
+  })
+})
diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx
index 295b553b04..4607f7b693 100644
--- a/web/app/components/app-sidebar/navLink.tsx
+++ b/web/app/components/app-sidebar/navLink.tsx
@@ -44,20 +44,29 @@ export default function NavLink({
       key={name}
       href={href}
       className={classNames(
-        isActive ? 'bg-state-accent-active text-text-accent font-semibold' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover',
-        'group flex items-center h-9 rounded-md py-2 text-sm font-normal',
+        isActive ? 'bg-state-accent-active font-semibold text-text-accent' : 'text-components-menu-item-text hover:bg-state-base-hover hover:text-components-menu-item-text-hover',
+        'group flex h-9 items-center rounded-md py-2 text-sm font-normal',
         mode === 'expand' ? 'px-3' : 'px-2.5',
       )}
       title={mode === 'collapse' ? name : ''}
     >