From df938a4543ce5a527984ef88fe27326f1a1c1e45 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 12 Jan 2026 15:07:53 +0800 Subject: [PATCH 01/29] ci: add HITL test env deployment action (#30846) --- .github/workflows/deploy-hitl.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/deploy-hitl.yml diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml new file mode 100644 index 0000000000..8144ba4f08 --- /dev/null +++ b/.github/workflows/deploy-hitl.yml @@ -0,0 +1,29 @@ +name: Deploy HITL + +on: + workflow_run: + workflows: ["Build and Push API & Web"] + branches: + - "feat/hitl-frontend" + - "feat/hitl-backend" + types: + - completed + +jobs: + deploy: + runs-on: ubuntu-latest + if: | + github.event.workflow_run.conclusion == 'success' && + ( + github.event.workflow_run.head_branch == 'feat/hitl-frontend' || + github.event.workflow_run.head_branch == 'feat/hitl-backend' + ) + steps: + - name: Deploy to server + uses: appleboy/ssh-action@v0.1.8 + with: + host: ${{ secrets.HITL_SSH_HOST }} + username: ${{ secrets.SSH_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + script: | + ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} From 00698e41b7c6bbded6869dbfaa05bf35350039e3 Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Mon, 12 Jan 2026 15:33:20 +0800 Subject: [PATCH 02/29] build: limit esbuild, glob, docker base version to avoid cve (#30848) --- web/Dockerfile | 2 +- web/package.json | 4 +- web/pnpm-lock.yaml | 626 +++++++++++++-------------------------------- 3 files changed, 187 insertions(+), 445 deletions(-) diff --git a/web/Dockerfile b/web/Dockerfile index 8697793145..9e08910a77 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,5 +1,5 @@ # base image -FROM node:22-alpine3.21 AS base +FROM node:22.21.1-alpine3.23 AS base LABEL maintainer="takatost@gmail.com" # if you located in China, you can use aliyun mirror to speed up diff --git a/web/package.json b/web/package.json index 7537b942fb..4019e49cd9 100644 --- a/web/package.json +++ b/web/package.json @@ -236,7 +236,8 @@ "brace-expansion@<2.0.2": "2.0.2", "devalue@<5.3.2": "5.3.2", "es-iterator-helpers": "npm:@nolyfill/es-iterator-helpers@^1", - "esbuild@<0.25.0": "0.25.0", + "esbuild@<0.27.2": "0.27.2", + "glob@>=10.2.0,<10.5.0": "11.1.0", "hasown": "npm:@nolyfill/hasown@^1", "is-arguments": "npm:@nolyfill/is-arguments@^1", "is-core-module": "npm:@nolyfill/is-core-module@^1", @@ -278,7 +279,6 @@ "@types/react-dom": "~19.2.3", "brace-expansion": "~2.0", "canvas": "^3.2.0", - "esbuild": "~0.25.0", "pbkdf2": "~3.1.3", "prismjs": "~1.30", "string-width": "~4.2.3" diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index f39f7503d9..853c366025 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -5,23 +5,17 @@ settings: excludeLinksFromLockfile: false overrides: + '@eslint/plugin-kit': ~0.3 '@types/react': ~19.2.7 '@types/react-dom': ~19.2.3 - string-width: ~4.2.3 - '@eslint/plugin-kit': ~0.3 + brace-expansion: ~2.0 canvas: ^3.2.0 - esbuild: ~0.25.0 pbkdf2: ~3.1.3 prismjs: ~1.30 - brace-expansion: ~2.0 - '@monaco-editor/loader': 1.5.0 + string-width: ~4.2.3 '@eslint/plugin-kit@<0.3.4': 0.3.4 - brace-expansion@<2.0.2: 2.0.2 - devalue@<5.3.2: 5.3.2 - esbuild@<0.25.0: 0.25.0 - pbkdf2@<3.1.3: 3.1.3 - prismjs@<1.30.0: 1.30.0 - vite@<6.4.1: 6.4.1 + '@monaco-editor/loader': 1.5.0 + '@nolyfill/safe-buffer': npm:safe-buffer@^5.2.1 array-includes: npm:@nolyfill/array-includes@^1 array.prototype.findlast: npm:@nolyfill/array.prototype.findlast@^1 array.prototype.findlastindex: npm:@nolyfill/array.prototype.findlastindex@^1 @@ -29,7 +23,11 @@ overrides: array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1 array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1 assert: npm:@nolyfill/assert@^1 + brace-expansion@<2.0.2: 2.0.2 + devalue@<5.3.2: 5.3.2 es-iterator-helpers: npm:@nolyfill/es-iterator-helpers@^1 + esbuild@<0.27.2: 0.27.2 + glob@>=10.2.0,<10.5.0: 11.1.0 hasown: npm:@nolyfill/hasown@^1 is-arguments: npm:@nolyfill/is-arguments@^1 is-core-module: npm:@nolyfill/is-core-module@^1 @@ -41,8 +39,9 @@ overrides: object.fromentries: npm:@nolyfill/object.fromentries@^1 object.groupby: npm:@nolyfill/object.groupby@^1 object.values: npm:@nolyfill/object.values@^1 + pbkdf2@<3.1.3: 3.1.3 + prismjs@<1.30.0: 1.30.0 safe-buffer: ^5.2.1 - '@nolyfill/safe-buffer': npm:safe-buffer@^5.2.1 safe-regex-test: npm:@nolyfill/safe-regex-test@^1 safer-buffer: npm:@nolyfill/safer-buffer@^1 side-channel: npm:@nolyfill/side-channel@^1 @@ -51,6 +50,7 @@ overrides: string.prototype.repeat: npm:@nolyfill/string.prototype.repeat@^1 string.prototype.trimend: npm:@nolyfill/string.prototype.trimend@^1 typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1 + vite@<6.4.1: 6.4.1 which-typed-array: npm:@nolyfill/which-typed-array@^1 importers: @@ -360,7 +360,7 @@ importers: version: 2.3.13(eslint@9.39.2(jiti@1.21.7))(typescript@5.9.3) '@mdx-js/loader': specifier: ^3.1.1 - version: 3.1.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + version: 3.1.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) '@mdx-js/react': specifier: ^3.1.1 version: 3.1.1(@types/react@19.2.7)(react@19.2.3) @@ -372,7 +372,7 @@ importers: version: 15.5.9 '@next/mdx': specifier: 15.5.9 - version: 15.5.9(@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.7)(react@19.2.3)) + version: 15.5.9(@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.7)(react@19.2.3)) '@rgrove/parse-xml': specifier: ^4.2.0 version: 4.2.0 @@ -393,7 +393,7 @@ importers: version: 9.1.13(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2))) '@storybook/nextjs': specifier: 9.1.13 - version: 9.1.13(esbuild@0.25.0)(next@15.5.9(@babel/core@7.28.5)(@playwright/test@1.57.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(type-fest@4.2.0)(typescript@5.9.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + version: 9.1.13(esbuild@0.27.2)(next@15.5.9(@babel/core@7.28.5)(@playwright/test@1.57.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(type-fest@4.2.0)(typescript@5.9.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': specifier: 9.1.17 version: 9.1.17(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3) @@ -1411,308 +1411,158 @@ packages: resolution: {integrity: sha512-Q9hjxWI5xBM+qW2enxfe8wDKdFWMfd0Z29k5ZJnuBqD/CasY5Zryj09aCA6owbGATWz+39p5uIdaHXpopOcG8g==} engines: {node: '>=10'} - '@esbuild/aix-ppc64@0.25.0': - resolution: {integrity: sha512-O7vun9Sf8DFjH2UtqK8Ku3LkquL9SZL8OLY1T5NZkA34+wG3OQF7cl4Ql8vdNzM6fzBbYfLaiRLIOZ+2FOCgBQ==} + '@esbuild/aix-ppc64@0.27.2': + resolution: {integrity: sha512-GZMB+a0mOMZs4MpDbj8RJp4cw+w1WV5NYD6xzgvzUJ5Ek2jerwfO2eADyI6ExDSUED+1X8aMbegahsJi+8mgpw==} engines: {node: '>=18'} cpu: [ppc64] os: [aix] - '@esbuild/aix-ppc64@0.25.12': - resolution: {integrity: sha512-Hhmwd6CInZ3dwpuGTF8fJG6yoWmsToE+vYgD4nytZVxcu1ulHpUQRAB1UJ8+N1Am3Mz4+xOByoQoSZf4D+CpkA==} - engines: {node: '>=18'} - cpu: [ppc64] - os: [aix] - - '@esbuild/android-arm64@0.25.0': - resolution: {integrity: sha512-grvv8WncGjDSyUBjN9yHXNt+cq0snxXbDxy5pJtzMKGmmpPxeAmAhWxXI+01lU5rwZomDgD3kJwulEnhTRUd6g==} + '@esbuild/android-arm64@0.27.2': + resolution: {integrity: sha512-pvz8ZZ7ot/RBphf8fv60ljmaoydPU12VuXHImtAs0XhLLw+EXBi2BLe3OYSBslR4rryHvweW5gmkKFwTiFy6KA==} engines: {node: '>=18'} cpu: [arm64] os: [android] - '@esbuild/android-arm64@0.25.12': - resolution: {integrity: sha512-6AAmLG7zwD1Z159jCKPvAxZd4y/VTO0VkprYy+3N2FtJ8+BQWFXU+OxARIwA46c5tdD9SsKGZ/1ocqBS/gAKHg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [android] - - '@esbuild/android-arm@0.25.0': - resolution: {integrity: sha512-PTyWCYYiU0+1eJKmw21lWtC+d08JDZPQ5g+kFyxP0V+es6VPPSUhM6zk8iImp2jbV6GwjX4pap0JFbUQN65X1g==} + '@esbuild/android-arm@0.27.2': + resolution: {integrity: sha512-DVNI8jlPa7Ujbr1yjU2PfUSRtAUZPG9I1RwW4F4xFB1Imiu2on0ADiI/c3td+KmDtVKNbi+nffGDQMfcIMkwIA==} engines: {node: '>=18'} cpu: [arm] os: [android] - '@esbuild/android-arm@0.25.12': - resolution: {integrity: sha512-VJ+sKvNA/GE7Ccacc9Cha7bpS8nyzVv0jdVgwNDaR4gDMC/2TTRc33Ip8qrNYUcpkOHUT5OZ0bUcNNVZQ9RLlg==} - engines: {node: '>=18'} - cpu: [arm] - os: [android] - - '@esbuild/android-x64@0.25.0': - resolution: {integrity: sha512-m/ix7SfKG5buCnxasr52+LI78SQ+wgdENi9CqyCXwjVR2X4Jkz+BpC3le3AoBPYTC9NHklwngVXvbJ9/Akhrfg==} + '@esbuild/android-x64@0.27.2': + resolution: {integrity: sha512-z8Ank4Byh4TJJOh4wpz8g2vDy75zFL0TlZlkUkEwYXuPSgX8yzep596n6mT7905kA9uHZsf/o2OJZubl2l3M7A==} engines: {node: '>=18'} cpu: [x64] os: [android] - '@esbuild/android-x64@0.25.12': - resolution: {integrity: sha512-5jbb+2hhDHx5phYR2By8GTWEzn6I9UqR11Kwf22iKbNpYrsmRB18aX/9ivc5cabcUiAT/wM+YIZ6SG9QO6a8kg==} - engines: {node: '>=18'} - cpu: [x64] - os: [android] - - '@esbuild/darwin-arm64@0.25.0': - resolution: {integrity: sha512-mVwdUb5SRkPayVadIOI78K7aAnPamoeFR2bT5nszFUZ9P8UpK4ratOdYbZZXYSqPKMHfS1wdHCJk1P1EZpRdvw==} + '@esbuild/darwin-arm64@0.27.2': + resolution: {integrity: sha512-davCD2Zc80nzDVRwXTcQP/28fiJbcOwvdolL0sOiOsbwBa72kegmVU0Wrh1MYrbuCL98Omp5dVhQFWRKR2ZAlg==} engines: {node: '>=18'} cpu: [arm64] os: [darwin] - '@esbuild/darwin-arm64@0.25.12': - resolution: {integrity: sha512-N3zl+lxHCifgIlcMUP5016ESkeQjLj/959RxxNYIthIg+CQHInujFuXeWbWMgnTo4cp5XVHqFPmpyu9J65C1Yg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [darwin] - - '@esbuild/darwin-x64@0.25.0': - resolution: {integrity: sha512-DgDaYsPWFTS4S3nWpFcMn/33ZZwAAeAFKNHNa1QN0rI4pUjgqf0f7ONmXf6d22tqTY+H9FNdgeaAa+YIFUn2Rg==} + '@esbuild/darwin-x64@0.27.2': + resolution: {integrity: sha512-ZxtijOmlQCBWGwbVmwOF/UCzuGIbUkqB1faQRf5akQmxRJ1ujusWsb3CVfk/9iZKr2L5SMU5wPBi1UWbvL+VQA==} engines: {node: '>=18'} cpu: [x64] os: [darwin] - '@esbuild/darwin-x64@0.25.12': - resolution: {integrity: sha512-HQ9ka4Kx21qHXwtlTUVbKJOAnmG1ipXhdWTmNXiPzPfWKpXqASVcWdnf2bnL73wgjNrFXAa3yYvBSd9pzfEIpA==} - engines: {node: '>=18'} - cpu: [x64] - os: [darwin] - - '@esbuild/freebsd-arm64@0.25.0': - resolution: {integrity: sha512-VN4ocxy6dxefN1MepBx/iD1dH5K8qNtNe227I0mnTRjry8tj5MRk4zprLEdG8WPyAPb93/e4pSgi1SoHdgOa4w==} + '@esbuild/freebsd-arm64@0.27.2': + resolution: {integrity: sha512-lS/9CN+rgqQ9czogxlMcBMGd+l8Q3Nj1MFQwBZJyoEKI50XGxwuzznYdwcav6lpOGv5BqaZXqvBSiB/kJ5op+g==} engines: {node: '>=18'} cpu: [arm64] os: [freebsd] - '@esbuild/freebsd-arm64@0.25.12': - resolution: {integrity: sha512-gA0Bx759+7Jve03K1S0vkOu5Lg/85dou3EseOGUes8flVOGxbhDDh/iZaoek11Y8mtyKPGF3vP8XhnkDEAmzeg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [freebsd] - - '@esbuild/freebsd-x64@0.25.0': - resolution: {integrity: sha512-mrSgt7lCh07FY+hDD1TxiTyIHyttn6vnjesnPoVDNmDfOmggTLXRv8Id5fNZey1gl/V2dyVK1VXXqVsQIiAk+A==} + '@esbuild/freebsd-x64@0.27.2': + resolution: {integrity: sha512-tAfqtNYb4YgPnJlEFu4c212HYjQWSO/w/h/lQaBK7RbwGIkBOuNKQI9tqWzx7Wtp7bTPaGC6MJvWI608P3wXYA==} engines: {node: '>=18'} cpu: [x64] os: [freebsd] - '@esbuild/freebsd-x64@0.25.12': - resolution: {integrity: sha512-TGbO26Yw2xsHzxtbVFGEXBFH0FRAP7gtcPE7P5yP7wGy7cXK2oO7RyOhL5NLiqTlBh47XhmIUXuGciXEqYFfBQ==} - engines: {node: '>=18'} - cpu: [x64] - os: [freebsd] - - '@esbuild/linux-arm64@0.25.0': - resolution: {integrity: sha512-9QAQjTWNDM/Vk2bgBl17yWuZxZNQIF0OUUuPZRKoDtqF2k4EtYbpyiG5/Dk7nqeK6kIJWPYldkOcBqjXjrUlmg==} + '@esbuild/linux-arm64@0.27.2': + resolution: {integrity: sha512-hYxN8pr66NsCCiRFkHUAsxylNOcAQaxSSkHMMjcpx0si13t1LHFphxJZUiGwojB1a/Hd5OiPIqDdXONia6bhTw==} engines: {node: '>=18'} cpu: [arm64] os: [linux] - '@esbuild/linux-arm64@0.25.12': - resolution: {integrity: sha512-8bwX7a8FghIgrupcxb4aUmYDLp8pX06rGh5HqDT7bB+8Rdells6mHvrFHHW2JAOPZUbnjUpKTLg6ECyzvas2AQ==} - engines: {node: '>=18'} - cpu: [arm64] - os: [linux] - - '@esbuild/linux-arm@0.25.0': - resolution: {integrity: sha512-vkB3IYj2IDo3g9xX7HqhPYxVkNQe8qTK55fraQyTzTX/fxaDtXiEnavv9geOsonh2Fd2RMB+i5cbhu2zMNWJwg==} + '@esbuild/linux-arm@0.27.2': + resolution: {integrity: sha512-vWfq4GaIMP9AIe4yj1ZUW18RDhx6EPQKjwe7n8BbIecFtCQG4CfHGaHuh7fdfq+y3LIA2vGS/o9ZBGVxIDi9hw==} engines: {node: '>=18'} cpu: [arm] os: [linux] - '@esbuild/linux-arm@0.25.12': - resolution: {integrity: sha512-lPDGyC1JPDou8kGcywY0YILzWlhhnRjdof3UlcoqYmS9El818LLfJJc3PXXgZHrHCAKs/Z2SeZtDJr5MrkxtOw==} - engines: {node: '>=18'} - cpu: [arm] - os: [linux] - - '@esbuild/linux-ia32@0.25.0': - resolution: {integrity: sha512-43ET5bHbphBegyeqLb7I1eYn2P/JYGNmzzdidq/w0T8E2SsYL1U6un2NFROFRg1JZLTzdCoRomg8Rvf9M6W6Gg==} + '@esbuild/linux-ia32@0.27.2': + resolution: {integrity: sha512-MJt5BRRSScPDwG2hLelYhAAKh9imjHK5+NE/tvnRLbIqUWa+0E9N4WNMjmp/kXXPHZGqPLxggwVhz7QP8CTR8w==} engines: {node: '>=18'} cpu: [ia32] os: [linux] - '@esbuild/linux-ia32@0.25.12': - resolution: {integrity: sha512-0y9KrdVnbMM2/vG8KfU0byhUN+EFCny9+8g202gYqSSVMonbsCfLjUO+rCci7pM0WBEtz+oK/PIwHkzxkyharA==} - engines: {node: '>=18'} - cpu: [ia32] - os: [linux] - - '@esbuild/linux-loong64@0.25.0': - resolution: {integrity: sha512-fC95c/xyNFueMhClxJmeRIj2yrSMdDfmqJnyOY4ZqsALkDrrKJfIg5NTMSzVBr5YW1jf+l7/cndBfP3MSDpoHw==} + '@esbuild/linux-loong64@0.27.2': + resolution: {integrity: sha512-lugyF1atnAT463aO6KPshVCJK5NgRnU4yb3FUumyVz+cGvZbontBgzeGFO1nF+dPueHD367a2ZXe1NtUkAjOtg==} engines: {node: '>=18'} cpu: [loong64] os: [linux] - '@esbuild/linux-loong64@0.25.12': - resolution: {integrity: sha512-h///Lr5a9rib/v1GGqXVGzjL4TMvVTv+s1DPoxQdz7l/AYv6LDSxdIwzxkrPW438oUXiDtwM10o9PmwS/6Z0Ng==} - engines: {node: '>=18'} - cpu: [loong64] - os: [linux] - - '@esbuild/linux-mips64el@0.25.0': - resolution: {integrity: sha512-nkAMFju7KDW73T1DdH7glcyIptm95a7Le8irTQNO/qtkoyypZAnjchQgooFUDQhNAy4iu08N79W4T4pMBwhPwQ==} + '@esbuild/linux-mips64el@0.27.2': + resolution: {integrity: sha512-nlP2I6ArEBewvJ2gjrrkESEZkB5mIoaTswuqNFRv/WYd+ATtUpe9Y09RnJvgvdag7he0OWgEZWhviS1OTOKixw==} engines: {node: '>=18'} cpu: [mips64el] os: [linux] - '@esbuild/linux-mips64el@0.25.12': - resolution: {integrity: sha512-iyRrM1Pzy9GFMDLsXn1iHUm18nhKnNMWscjmp4+hpafcZjrr2WbT//d20xaGljXDBYHqRcl8HnxbX6uaA/eGVw==} - engines: {node: '>=18'} - cpu: [mips64el] - os: [linux] - - '@esbuild/linux-ppc64@0.25.0': - resolution: {integrity: sha512-NhyOejdhRGS8Iwv+KKR2zTq2PpysF9XqY+Zk77vQHqNbo/PwZCzB5/h7VGuREZm1fixhs4Q/qWRSi5zmAiO4Fw==} + '@esbuild/linux-ppc64@0.27.2': + resolution: {integrity: sha512-C92gnpey7tUQONqg1n6dKVbx3vphKtTHJaNG2Ok9lGwbZil6DrfyecMsp9CrmXGQJmZ7iiVXvvZH6Ml5hL6XdQ==} engines: {node: '>=18'} cpu: [ppc64] os: [linux] - '@esbuild/linux-ppc64@0.25.12': - resolution: {integrity: sha512-9meM/lRXxMi5PSUqEXRCtVjEZBGwB7P/D4yT8UG/mwIdze2aV4Vo6U5gD3+RsoHXKkHCfSxZKzmDssVlRj1QQA==} - engines: {node: '>=18'} - cpu: [ppc64] - os: [linux] - - '@esbuild/linux-riscv64@0.25.0': - resolution: {integrity: sha512-5S/rbP5OY+GHLC5qXp1y/Mx//e92L1YDqkiBbO9TQOvuFXM+iDqUNG5XopAnXoRH3FjIUDkeGcY1cgNvnXp/kA==} + '@esbuild/linux-riscv64@0.27.2': + resolution: {integrity: sha512-B5BOmojNtUyN8AXlK0QJyvjEZkWwy/FKvakkTDCziX95AowLZKR6aCDhG7LeF7uMCXEJqwa8Bejz5LTPYm8AvA==} engines: {node: '>=18'} cpu: [riscv64] os: [linux] - '@esbuild/linux-riscv64@0.25.12': - resolution: {integrity: sha512-Zr7KR4hgKUpWAwb1f3o5ygT04MzqVrGEGXGLnj15YQDJErYu/BGg+wmFlIDOdJp0PmB0lLvxFIOXZgFRrdjR0w==} - engines: {node: '>=18'} - cpu: [riscv64] - os: [linux] - - '@esbuild/linux-s390x@0.25.0': - resolution: {integrity: sha512-XM2BFsEBz0Fw37V0zU4CXfcfuACMrppsMFKdYY2WuTS3yi8O1nFOhil/xhKTmE1nPmVyvQJjJivgDT+xh8pXJA==} + '@esbuild/linux-s390x@0.27.2': + resolution: {integrity: sha512-p4bm9+wsPwup5Z8f4EpfN63qNagQ47Ua2znaqGH6bqLlmJ4bx97Y9JdqxgGZ6Y8xVTixUnEkoKSHcpRlDnNr5w==} engines: {node: '>=18'} cpu: [s390x] os: [linux] - '@esbuild/linux-s390x@0.25.12': - resolution: {integrity: sha512-MsKncOcgTNvdtiISc/jZs/Zf8d0cl/t3gYWX8J9ubBnVOwlk65UIEEvgBORTiljloIWnBzLs4qhzPkJcitIzIg==} - engines: {node: '>=18'} - cpu: [s390x] - os: [linux] - - '@esbuild/linux-x64@0.25.0': - resolution: {integrity: sha512-9yl91rHw/cpwMCNytUDxwj2XjFpxML0y9HAOH9pNVQDpQrBxHy01Dx+vaMu0N1CKa/RzBD2hB4u//nfc+Sd3Cw==} + '@esbuild/linux-x64@0.27.2': + resolution: {integrity: sha512-uwp2Tip5aPmH+NRUwTcfLb+W32WXjpFejTIOWZFw/v7/KnpCDKG66u4DLcurQpiYTiYwQ9B7KOeMJvLCu/OvbA==} engines: {node: '>=18'} cpu: [x64] os: [linux] - '@esbuild/linux-x64@0.25.12': - resolution: {integrity: sha512-uqZMTLr/zR/ed4jIGnwSLkaHmPjOjJvnm6TVVitAa08SLS9Z0VM8wIRx7gWbJB5/J54YuIMInDquWyYvQLZkgw==} - engines: {node: '>=18'} - cpu: [x64] - os: [linux] - - '@esbuild/netbsd-arm64@0.25.0': - resolution: {integrity: sha512-RuG4PSMPFfrkH6UwCAqBzauBWTygTvb1nxWasEJooGSJ/NwRw7b2HOwyRTQIU97Hq37l3npXoZGYMy3b3xYvPw==} + '@esbuild/netbsd-arm64@0.27.2': + resolution: {integrity: sha512-Kj6DiBlwXrPsCRDeRvGAUb/LNrBASrfqAIok+xB0LxK8CHqxZ037viF13ugfsIpePH93mX7xfJp97cyDuTZ3cw==} engines: {node: '>=18'} cpu: [arm64] os: [netbsd] - '@esbuild/netbsd-arm64@0.25.12': - resolution: {integrity: sha512-xXwcTq4GhRM7J9A8Gv5boanHhRa/Q9KLVmcyXHCTaM4wKfIpWkdXiMog/KsnxzJ0A1+nD+zoecuzqPmCRyBGjg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [netbsd] - - '@esbuild/netbsd-x64@0.25.0': - resolution: {integrity: sha512-jl+qisSB5jk01N5f7sPCsBENCOlPiS/xptD5yxOx2oqQfyourJwIKLRA2yqWdifj3owQZCL2sn6o08dBzZGQzA==} + '@esbuild/netbsd-x64@0.27.2': + resolution: {integrity: sha512-HwGDZ0VLVBY3Y+Nw0JexZy9o/nUAWq9MlV7cahpaXKW6TOzfVno3y3/M8Ga8u8Yr7GldLOov27xiCnqRZf0tCA==} engines: {node: '>=18'} cpu: [x64] os: [netbsd] - '@esbuild/netbsd-x64@0.25.12': - resolution: {integrity: sha512-Ld5pTlzPy3YwGec4OuHh1aCVCRvOXdH8DgRjfDy/oumVovmuSzWfnSJg+VtakB9Cm0gxNO9BzWkj6mtO1FMXkQ==} - engines: {node: '>=18'} - cpu: [x64] - os: [netbsd] - - '@esbuild/openbsd-arm64@0.25.0': - resolution: {integrity: sha512-21sUNbq2r84YE+SJDfaQRvdgznTD8Xc0oc3p3iW/a1EVWeNj/SdUCbm5U0itZPQYRuRTW20fPMWMpcrciH2EJw==} + '@esbuild/openbsd-arm64@0.27.2': + resolution: {integrity: sha512-DNIHH2BPQ5551A7oSHD0CKbwIA/Ox7+78/AWkbS5QoRzaqlev2uFayfSxq68EkonB+IKjiuxBFoV8ESJy8bOHA==} engines: {node: '>=18'} cpu: [arm64] os: [openbsd] - '@esbuild/openbsd-arm64@0.25.12': - resolution: {integrity: sha512-fF96T6KsBo/pkQI950FARU9apGNTSlZGsv1jZBAlcLL1MLjLNIWPBkj5NlSz8aAzYKg+eNqknrUJ24QBybeR5A==} - engines: {node: '>=18'} - cpu: [arm64] - os: [openbsd] - - '@esbuild/openbsd-x64@0.25.0': - resolution: {integrity: sha512-2gwwriSMPcCFRlPlKx3zLQhfN/2WjJ2NSlg5TKLQOJdV0mSxIcYNTMhk3H3ulL/cak+Xj0lY1Ym9ysDV1igceg==} + '@esbuild/openbsd-x64@0.27.2': + resolution: {integrity: sha512-/it7w9Nb7+0KFIzjalNJVR5bOzA9Vay+yIPLVHfIQYG/j+j9VTH84aNB8ExGKPU4AzfaEvN9/V4HV+F+vo8OEg==} engines: {node: '>=18'} cpu: [x64] os: [openbsd] - '@esbuild/openbsd-x64@0.25.12': - resolution: {integrity: sha512-MZyXUkZHjQxUvzK7rN8DJ3SRmrVrke8ZyRusHlP+kuwqTcfWLyqMOE3sScPPyeIXN/mDJIfGXvcMqCgYKekoQw==} - engines: {node: '>=18'} - cpu: [x64] - os: [openbsd] - - '@esbuild/openharmony-arm64@0.25.12': - resolution: {integrity: sha512-rm0YWsqUSRrjncSXGA7Zv78Nbnw4XL6/dzr20cyrQf7ZmRcsovpcRBdhD43Nuk3y7XIoW2OxMVvwuRvk9XdASg==} + '@esbuild/openharmony-arm64@0.27.2': + resolution: {integrity: sha512-LRBbCmiU51IXfeXk59csuX/aSaToeG7w48nMwA6049Y4J4+VbWALAuXcs+qcD04rHDuSCSRKdmY63sruDS5qag==} engines: {node: '>=18'} cpu: [arm64] os: [openharmony] - '@esbuild/sunos-x64@0.25.0': - resolution: {integrity: sha512-bxI7ThgLzPrPz484/S9jLlvUAHYMzy6I0XiU1ZMeAEOBcS0VePBFxh1JjTQt3Xiat5b6Oh4x7UC7IwKQKIJRIg==} + '@esbuild/sunos-x64@0.27.2': + resolution: {integrity: sha512-kMtx1yqJHTmqaqHPAzKCAkDaKsffmXkPHThSfRwZGyuqyIeBvf08KSsYXl+abf5HDAPMJIPnbBfXvP2ZC2TfHg==} engines: {node: '>=18'} cpu: [x64] os: [sunos] - '@esbuild/sunos-x64@0.25.12': - resolution: {integrity: sha512-3wGSCDyuTHQUzt0nV7bocDy72r2lI33QL3gkDNGkod22EsYl04sMf0qLb8luNKTOmgF/eDEDP5BFNwoBKH441w==} - engines: {node: '>=18'} - cpu: [x64] - os: [sunos] - - '@esbuild/win32-arm64@0.25.0': - resolution: {integrity: sha512-ZUAc2YK6JW89xTbXvftxdnYy3m4iHIkDtK3CLce8wg8M2L+YZhIvO1DKpxrd0Yr59AeNNkTiic9YLf6FTtXWMw==} + '@esbuild/win32-arm64@0.27.2': + resolution: {integrity: sha512-Yaf78O/B3Kkh+nKABUF++bvJv5Ijoy9AN1ww904rOXZFLWVc5OLOfL56W+C8F9xn5JQZa3UX6m+IktJnIb1Jjg==} engines: {node: '>=18'} cpu: [arm64] os: [win32] - '@esbuild/win32-arm64@0.25.12': - resolution: {integrity: sha512-rMmLrur64A7+DKlnSuwqUdRKyd3UE7oPJZmnljqEptesKM8wx9J8gx5u0+9Pq0fQQW8vqeKebwNXdfOyP+8Bsg==} - engines: {node: '>=18'} - cpu: [arm64] - os: [win32] - - '@esbuild/win32-ia32@0.25.0': - resolution: {integrity: sha512-eSNxISBu8XweVEWG31/JzjkIGbGIJN/TrRoiSVZwZ6pkC6VX4Im/WV2cz559/TXLcYbcrDN8JtKgd9DJVIo8GA==} + '@esbuild/win32-ia32@0.27.2': + resolution: {integrity: sha512-Iuws0kxo4yusk7sw70Xa2E2imZU5HoixzxfGCdxwBdhiDgt9vX9VUCBhqcwY7/uh//78A1hMkkROMJq9l27oLQ==} engines: {node: '>=18'} cpu: [ia32] os: [win32] - '@esbuild/win32-ia32@0.25.12': - resolution: {integrity: sha512-HkqnmmBoCbCwxUKKNPBixiWDGCpQGVsrQfJoVGYLPT41XWF8lHuE5N6WhVia2n4o5QK5M4tYr21827fNhi4byQ==} - engines: {node: '>=18'} - cpu: [ia32] - os: [win32] - - '@esbuild/win32-x64@0.25.0': - resolution: {integrity: sha512-ZENoHJBxA20C2zFzh6AI4fT6RraMzjYw4xKWemRTRmRVtN9c5DcH9r/f2ihEkMjOW5eGgrwCslG/+Y/3bL+DHQ==} - engines: {node: '>=18'} - cpu: [x64] - os: [win32] - - '@esbuild/win32-x64@0.25.12': - resolution: {integrity: sha512-alJC0uCZpTFrSL0CCDjcgleBXPnCrEAhTBILpeAp7M/OFgoqtAetfBzX0xM00MUsVVPpVjlPuMbREqnZCXaTnA==} + '@esbuild/win32-x64@0.27.2': + resolution: {integrity: sha512-sRdU18mcKf7F+YgheI/zGf5alZatMUTKj/jNS6l744f9u3WFu4v7twcUI9vu4mknF4Y9aDlblIie0IM+5xxaqQ==} engines: {node: '>=18'} cpu: [x64] os: [win32] @@ -5161,20 +5011,15 @@ packages: esbuild-register@3.6.0: resolution: {integrity: sha512-H2/S7Pm8a9CL1uhp9OvjwrBh5Pvx0H8qVOxNu8Wed9Y7qv56MPtq+GGM8RJpq6glYJn9Wspr8uw7l55uyinNeg==} peerDependencies: - esbuild: 0.25.0 + esbuild: 0.27.2 esbuild-wasm@0.27.2: resolution: {integrity: sha512-eUTnl8eh+v8UZIZh4MrMOKDAc8Lm7+NqP3pyuTORGFY1s/o9WoiJgKnwXy+te2J3hX7iRbFSHEyig7GsPeeJyw==} engines: {node: '>=18'} hasBin: true - esbuild@0.25.0: - resolution: {integrity: sha512-BXq5mqc8ltbaN34cDqWuYKyNhX8D/Z0J1xdtdQ8UcIIIyJyz+ZMKUt58tF3SrZ85jcfN/PZYhjR5uDQAYNVbuw==} - engines: {node: '>=18'} - hasBin: true - - esbuild@0.25.12: - resolution: {integrity: sha512-bbPBYYrtZbkt6Os6FiTLCTFxvq4tt3JKall1vRwshA3fdVztsLAatFaZobhkBC8/BrPetoa0oksYoKXoG4ryJg==} + esbuild@0.27.2: + resolution: {integrity: sha512-HyNQImnsOC7X9PMNaCIeAm4ISCQXs5a5YasTXVliKv4uuBo1dKrG0A+uQS8M5eXjVMnLg3WgXaKvprHlFJQffw==} engines: {node: '>=18'} hasBin: true @@ -9899,157 +9744,82 @@ snapshots: '@es-joy/resolve.exports@1.2.0': {} - '@esbuild/aix-ppc64@0.25.0': + '@esbuild/aix-ppc64@0.27.2': optional: true - '@esbuild/aix-ppc64@0.25.12': + '@esbuild/android-arm64@0.27.2': optional: true - '@esbuild/android-arm64@0.25.0': + '@esbuild/android-arm@0.27.2': optional: true - '@esbuild/android-arm64@0.25.12': + '@esbuild/android-x64@0.27.2': optional: true - '@esbuild/android-arm@0.25.0': + '@esbuild/darwin-arm64@0.27.2': optional: true - '@esbuild/android-arm@0.25.12': + '@esbuild/darwin-x64@0.27.2': optional: true - '@esbuild/android-x64@0.25.0': + '@esbuild/freebsd-arm64@0.27.2': optional: true - '@esbuild/android-x64@0.25.12': + '@esbuild/freebsd-x64@0.27.2': optional: true - '@esbuild/darwin-arm64@0.25.0': + '@esbuild/linux-arm64@0.27.2': optional: true - '@esbuild/darwin-arm64@0.25.12': + '@esbuild/linux-arm@0.27.2': optional: true - '@esbuild/darwin-x64@0.25.0': + '@esbuild/linux-ia32@0.27.2': optional: true - '@esbuild/darwin-x64@0.25.12': + '@esbuild/linux-loong64@0.27.2': optional: true - '@esbuild/freebsd-arm64@0.25.0': + '@esbuild/linux-mips64el@0.27.2': optional: true - '@esbuild/freebsd-arm64@0.25.12': + '@esbuild/linux-ppc64@0.27.2': optional: true - '@esbuild/freebsd-x64@0.25.0': + '@esbuild/linux-riscv64@0.27.2': optional: true - '@esbuild/freebsd-x64@0.25.12': + '@esbuild/linux-s390x@0.27.2': optional: true - '@esbuild/linux-arm64@0.25.0': + '@esbuild/linux-x64@0.27.2': optional: true - '@esbuild/linux-arm64@0.25.12': + '@esbuild/netbsd-arm64@0.27.2': optional: true - '@esbuild/linux-arm@0.25.0': + '@esbuild/netbsd-x64@0.27.2': optional: true - '@esbuild/linux-arm@0.25.12': + '@esbuild/openbsd-arm64@0.27.2': optional: true - '@esbuild/linux-ia32@0.25.0': + '@esbuild/openbsd-x64@0.27.2': optional: true - '@esbuild/linux-ia32@0.25.12': + '@esbuild/openharmony-arm64@0.27.2': optional: true - '@esbuild/linux-loong64@0.25.0': + '@esbuild/sunos-x64@0.27.2': optional: true - '@esbuild/linux-loong64@0.25.12': + '@esbuild/win32-arm64@0.27.2': optional: true - '@esbuild/linux-mips64el@0.25.0': + '@esbuild/win32-ia32@0.27.2': optional: true - '@esbuild/linux-mips64el@0.25.12': - optional: true - - '@esbuild/linux-ppc64@0.25.0': - optional: true - - '@esbuild/linux-ppc64@0.25.12': - optional: true - - '@esbuild/linux-riscv64@0.25.0': - optional: true - - '@esbuild/linux-riscv64@0.25.12': - optional: true - - '@esbuild/linux-s390x@0.25.0': - optional: true - - '@esbuild/linux-s390x@0.25.12': - optional: true - - '@esbuild/linux-x64@0.25.0': - optional: true - - '@esbuild/linux-x64@0.25.12': - optional: true - - '@esbuild/netbsd-arm64@0.25.0': - optional: true - - '@esbuild/netbsd-arm64@0.25.12': - optional: true - - '@esbuild/netbsd-x64@0.25.0': - optional: true - - '@esbuild/netbsd-x64@0.25.12': - optional: true - - '@esbuild/openbsd-arm64@0.25.0': - optional: true - - '@esbuild/openbsd-arm64@0.25.12': - optional: true - - '@esbuild/openbsd-x64@0.25.0': - optional: true - - '@esbuild/openbsd-x64@0.25.12': - optional: true - - '@esbuild/openharmony-arm64@0.25.12': - optional: true - - '@esbuild/sunos-x64@0.25.0': - optional: true - - '@esbuild/sunos-x64@0.25.12': - optional: true - - '@esbuild/win32-arm64@0.25.0': - optional: true - - '@esbuild/win32-arm64@0.25.12': - optional: true - - '@esbuild/win32-ia32@0.25.0': - optional: true - - '@esbuild/win32-ia32@0.25.12': - optional: true - - '@esbuild/win32-x64@0.25.0': - optional: true - - '@esbuild/win32-x64@0.25.12': + '@esbuild/win32-x64@0.27.2': optional: true '@eslint-community/eslint-plugin-eslint-comments@4.5.0(eslint@9.39.2(jiti@1.21.7))': @@ -10638,12 +10408,12 @@ snapshots: lexical: 0.38.2 yjs: 13.6.27 - '@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3))': + '@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: '@mdx-js/mdx': 3.1.1 source-map: 0.7.6 optionalDependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) transitivePeerDependencies: - supports-color @@ -10729,11 +10499,11 @@ snapshots: dependencies: fast-glob: 3.3.1 - '@next/mdx@15.5.9(@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.7)(react@19.2.3))': + '@next/mdx@15.5.9(@mdx-js/loader@3.1.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)))(@mdx-js/react@3.1.1(@types/react@19.2.7)(react@19.2.3))': dependencies: source-map: 0.7.6 optionalDependencies: - '@mdx-js/loader': 3.1.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + '@mdx-js/loader': 3.1.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) '@mdx-js/react': 3.1.1(@types/react@19.2.7)(react@19.2.3) '@next/swc-darwin-arm64@15.5.7': @@ -11006,7 +10776,7 @@ snapshots: playwright: 1.57.0 optional: true - '@pmmmwh/react-refresh-webpack-plugin@0.5.17(react-refresh@0.14.2)(type-fest@4.2.0)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3))': + '@pmmmwh/react-refresh-webpack-plugin@0.5.17(react-refresh@0.14.2)(type-fest@4.2.0)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: ansi-html: 0.0.9 core-js-pure: 3.47.0 @@ -11016,7 +10786,7 @@ snapshots: react-refresh: 0.14.2 schema-utils: 4.3.3 source-map: 0.7.6 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) optionalDependencies: type-fest: 4.2.0 webpack-hot-middleware: 2.26.1 @@ -11547,22 +11317,22 @@ snapshots: storybook: 9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)) ts-dedent: 2.2.0 - '@storybook/builder-webpack5@9.1.13(esbuild@0.25.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3)': + '@storybook/builder-webpack5@9.1.13(esbuild@0.27.2)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3)': dependencies: '@storybook/core-webpack': 9.1.13(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2))) case-sensitive-paths-webpack-plugin: 2.4.0 cjs-module-lexer: 1.4.3 - css-loader: 6.11.0(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + css-loader: 6.11.0(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) es-module-lexer: 1.7.0 - fork-ts-checker-webpack-plugin: 8.0.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) - html-webpack-plugin: 5.6.5(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + fork-ts-checker-webpack-plugin: 8.0.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) + html-webpack-plugin: 5.6.5(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) magic-string: 0.30.21 storybook: 9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)) - style-loader: 3.3.4(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) - terser-webpack-plugin: 5.3.15(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + style-loader: 3.3.4(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) + terser-webpack-plugin: 5.3.15(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) ts-dedent: 2.2.0 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) - webpack-dev-middleware: 6.1.3(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) + webpack-dev-middleware: 6.1.3(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) webpack-hot-middleware: 2.26.1 webpack-virtual-modules: 0.6.2 optionalDependencies: @@ -11591,7 +11361,7 @@ snapshots: react: 19.2.3 react-dom: 19.2.3(react@19.2.3) - '@storybook/nextjs@9.1.13(esbuild@0.25.0)(next@15.5.9(@babel/core@7.28.5)(@playwright/test@1.57.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(type-fest@4.2.0)(typescript@5.9.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3))': + '@storybook/nextjs@9.1.13(esbuild@0.27.2)(next@15.5.9(@babel/core@7.28.5)(@playwright/test@1.57.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(type-fest@4.2.0)(typescript@5.9.3)(uglify-js@3.19.3)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: '@babel/core': 7.28.5 '@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.5) @@ -11606,33 +11376,33 @@ snapshots: '@babel/preset-react': 7.28.5(@babel/core@7.28.5) '@babel/preset-typescript': 7.28.5(@babel/core@7.28.5) '@babel/runtime': 7.28.4 - '@pmmmwh/react-refresh-webpack-plugin': 0.5.17(react-refresh@0.14.2)(type-fest@4.2.0)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) - '@storybook/builder-webpack5': 9.1.13(esbuild@0.25.0)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3) - '@storybook/preset-react-webpack': 9.1.13(esbuild@0.25.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3) + '@pmmmwh/react-refresh-webpack-plugin': 0.5.17(react-refresh@0.14.2)(type-fest@4.2.0)(webpack-hot-middleware@2.26.1)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/builder-webpack5': 9.1.13(esbuild@0.27.2)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3) + '@storybook/preset-react-webpack': 9.1.13(esbuild@0.27.2)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3) '@storybook/react': 9.1.13(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3) '@types/semver': 7.7.1 - babel-loader: 9.2.1(@babel/core@7.28.5)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) - css-loader: 6.11.0(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + babel-loader: 9.2.1(@babel/core@7.28.5)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) + css-loader: 6.11.0(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) image-size: 2.0.2 loader-utils: 3.3.1 next: 15.5.9(@babel/core@7.28.5)(@playwright/test@1.57.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(sass@1.95.0) - node-polyfill-webpack-plugin: 2.0.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + node-polyfill-webpack-plugin: 2.0.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) postcss: 8.5.6 - postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) react: 19.2.3 react-dom: 19.2.3(react@19.2.3) react-refresh: 0.14.2 resolve-url-loader: 5.0.0 - sass-loader: 16.0.6(sass@1.95.0)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + sass-loader: 16.0.6(sass@1.95.0)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) semver: 7.7.3 storybook: 9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)) - style-loader: 3.3.4(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + style-loader: 3.3.4(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) styled-jsx: 5.1.7(@babel/core@7.28.5)(react@19.2.3) tsconfig-paths: 4.2.0 tsconfig-paths-webpack-plugin: 4.2.0 optionalDependencies: typescript: 5.9.3 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) transitivePeerDependencies: - '@rspack/core' - '@swc/core' @@ -11651,10 +11421,10 @@ snapshots: - webpack-hot-middleware - webpack-plugin-serve - '@storybook/preset-react-webpack@9.1.13(esbuild@0.25.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3)': + '@storybook/preset-react-webpack@9.1.13(esbuild@0.27.2)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))(typescript@5.9.3)(uglify-js@3.19.3)': dependencies: '@storybook/core-webpack': 9.1.13(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2))) - '@storybook/react-docgen-typescript-plugin': 1.0.6--canary.9.0c3f3b7.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + '@storybook/react-docgen-typescript-plugin': 1.0.6--canary.9.0c3f3b7.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) '@types/semver': 7.7.1 find-up: 7.0.0 magic-string: 0.30.21 @@ -11665,7 +11435,7 @@ snapshots: semver: 7.7.3 storybook: 9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)) tsconfig-paths: 4.2.0 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -11675,7 +11445,7 @@ snapshots: - uglify-js - webpack-cli - '@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3))': + '@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: debug: 4.4.3 endent: 2.1.0 @@ -11685,7 +11455,7 @@ snapshots: react-docgen-typescript: 2.4.0(typescript@5.9.3) tslib: 2.8.1 typescript: 5.9.3 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) transitivePeerDependencies: - supports-color @@ -12865,12 +12635,12 @@ snapshots: postcss: 8.5.6 postcss-value-parser: 4.2.0 - babel-loader@9.2.1(@babel/core@7.28.5)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + babel-loader@9.2.1(@babel/core@7.28.5)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@babel/core': 7.28.5 find-cache-dir: 4.0.0 schema-utils: 4.3.3 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) babel-plugin-polyfill-corejs2@0.4.14(@babel/core@7.28.5): dependencies: @@ -13346,7 +13116,7 @@ snapshots: randombytes: 2.1.0 randomfill: 1.0.4 - css-loader@6.11.0(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + css-loader@6.11.0(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: icss-utils: 5.1.0(postcss@8.5.6) postcss: 8.5.6 @@ -13357,7 +13127,7 @@ snapshots: postcss-value-parser: 4.2.0 semver: 7.7.3 optionalDependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) css-mediaquery@0.1.2: {} @@ -13778,71 +13548,43 @@ snapshots: esast-util-from-estree: 2.0.0 vfile-message: 4.0.3 - esbuild-register@3.6.0(esbuild@0.25.0): + esbuild-register@3.6.0(esbuild@0.27.2): dependencies: debug: 4.4.3 - esbuild: 0.25.0 + esbuild: 0.27.2 transitivePeerDependencies: - supports-color esbuild-wasm@0.27.2: {} - esbuild@0.25.0: + esbuild@0.27.2: optionalDependencies: - '@esbuild/aix-ppc64': 0.25.0 - '@esbuild/android-arm': 0.25.0 - '@esbuild/android-arm64': 0.25.0 - '@esbuild/android-x64': 0.25.0 - '@esbuild/darwin-arm64': 0.25.0 - '@esbuild/darwin-x64': 0.25.0 - '@esbuild/freebsd-arm64': 0.25.0 - '@esbuild/freebsd-x64': 0.25.0 - '@esbuild/linux-arm': 0.25.0 - '@esbuild/linux-arm64': 0.25.0 - '@esbuild/linux-ia32': 0.25.0 - '@esbuild/linux-loong64': 0.25.0 - '@esbuild/linux-mips64el': 0.25.0 - '@esbuild/linux-ppc64': 0.25.0 - '@esbuild/linux-riscv64': 0.25.0 - '@esbuild/linux-s390x': 0.25.0 - '@esbuild/linux-x64': 0.25.0 - '@esbuild/netbsd-arm64': 0.25.0 - '@esbuild/netbsd-x64': 0.25.0 - '@esbuild/openbsd-arm64': 0.25.0 - '@esbuild/openbsd-x64': 0.25.0 - '@esbuild/sunos-x64': 0.25.0 - '@esbuild/win32-arm64': 0.25.0 - '@esbuild/win32-ia32': 0.25.0 - '@esbuild/win32-x64': 0.25.0 - - esbuild@0.25.12: - optionalDependencies: - '@esbuild/aix-ppc64': 0.25.12 - '@esbuild/android-arm': 0.25.12 - '@esbuild/android-arm64': 0.25.12 - '@esbuild/android-x64': 0.25.12 - '@esbuild/darwin-arm64': 0.25.12 - '@esbuild/darwin-x64': 0.25.12 - '@esbuild/freebsd-arm64': 0.25.12 - '@esbuild/freebsd-x64': 0.25.12 - '@esbuild/linux-arm': 0.25.12 - '@esbuild/linux-arm64': 0.25.12 - '@esbuild/linux-ia32': 0.25.12 - '@esbuild/linux-loong64': 0.25.12 - '@esbuild/linux-mips64el': 0.25.12 - '@esbuild/linux-ppc64': 0.25.12 - '@esbuild/linux-riscv64': 0.25.12 - '@esbuild/linux-s390x': 0.25.12 - '@esbuild/linux-x64': 0.25.12 - '@esbuild/netbsd-arm64': 0.25.12 - '@esbuild/netbsd-x64': 0.25.12 - '@esbuild/openbsd-arm64': 0.25.12 - '@esbuild/openbsd-x64': 0.25.12 - '@esbuild/openharmony-arm64': 0.25.12 - '@esbuild/sunos-x64': 0.25.12 - '@esbuild/win32-arm64': 0.25.12 - '@esbuild/win32-ia32': 0.25.12 - '@esbuild/win32-x64': 0.25.12 + '@esbuild/aix-ppc64': 0.27.2 + '@esbuild/android-arm': 0.27.2 + '@esbuild/android-arm64': 0.27.2 + '@esbuild/android-x64': 0.27.2 + '@esbuild/darwin-arm64': 0.27.2 + '@esbuild/darwin-x64': 0.27.2 + '@esbuild/freebsd-arm64': 0.27.2 + '@esbuild/freebsd-x64': 0.27.2 + '@esbuild/linux-arm': 0.27.2 + '@esbuild/linux-arm64': 0.27.2 + '@esbuild/linux-ia32': 0.27.2 + '@esbuild/linux-loong64': 0.27.2 + '@esbuild/linux-mips64el': 0.27.2 + '@esbuild/linux-ppc64': 0.27.2 + '@esbuild/linux-riscv64': 0.27.2 + '@esbuild/linux-s390x': 0.27.2 + '@esbuild/linux-x64': 0.27.2 + '@esbuild/netbsd-arm64': 0.27.2 + '@esbuild/netbsd-x64': 0.27.2 + '@esbuild/openbsd-arm64': 0.27.2 + '@esbuild/openbsd-x64': 0.27.2 + '@esbuild/openharmony-arm64': 0.27.2 + '@esbuild/sunos-x64': 0.27.2 + '@esbuild/win32-arm64': 0.27.2 + '@esbuild/win32-ia32': 0.27.2 + '@esbuild/win32-x64': 0.27.2 escalade@3.2.0: {} @@ -14456,7 +14198,7 @@ snapshots: cross-spawn: 7.0.6 signal-exit: 4.1.0 - fork-ts-checker-webpack-plugin@8.0.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + fork-ts-checker-webpack-plugin@8.0.0(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@babel/code-frame': 7.27.1 chalk: 4.1.2 @@ -14471,7 +14213,7 @@ snapshots: semver: 7.7.3 tapable: 2.3.0 typescript: 5.9.3 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) format@0.2.2: {} @@ -14786,7 +14528,7 @@ snapshots: html-void-elements@3.0.0: {} - html-webpack-plugin@5.6.5(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + html-webpack-plugin@5.6.5(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@types/html-minifier-terser': 6.1.0 html-minifier-terser: 6.1.0 @@ -14794,7 +14536,7 @@ snapshots: pretty-error: 4.0.0 tapable: 2.3.0 optionalDependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) htmlparser2@6.1.0: dependencies: @@ -15944,7 +15686,7 @@ snapshots: node-addon-api@7.1.1: optional: true - node-polyfill-webpack-plugin@2.0.1(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + node-polyfill-webpack-plugin@2.0.1(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: assert: '@nolyfill/assert@1.0.26' browserify-zlib: 0.2.0 @@ -15971,7 +15713,7 @@ snapshots: url: 0.11.4 util: 0.12.5 vm-browserify: 1.1.2 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) node-releases@2.0.27: {} @@ -16287,14 +16029,14 @@ snapshots: tsx: 4.21.0 yaml: 2.8.2 - postcss-loader@8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + postcss-loader@8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: cosmiconfig: 9.0.0(typescript@5.9.3) jiti: 2.6.1 postcss: 8.5.6 semver: 7.7.3 optionalDependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) transitivePeerDependencies: - typescript @@ -16612,7 +16354,7 @@ snapshots: '@rollup/pluginutils': 5.3.0(rollup@4.53.5) '@types/node': 20.19.26 bippy: 0.3.34(@types/react@19.2.7)(react@19.2.3) - esbuild: 0.25.12 + esbuild: 0.27.2 estree-walker: 3.0.3 kleur: 4.1.5 mri: 1.2.0 @@ -16995,12 +16737,12 @@ snapshots: safe-buffer@5.2.1: {} - sass-loader@16.0.6(sass@1.95.0)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + sass-loader@16.0.6(sass@1.95.0)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: neo-async: 2.6.2 optionalDependencies: sass: 1.95.0 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) sass@1.95.0: dependencies: @@ -17225,8 +16967,8 @@ snapshots: '@vitest/mocker': 3.2.4(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)) '@vitest/spy': 3.2.4 better-opn: 3.0.2 - esbuild: 0.25.0 - esbuild-register: 3.6.0(esbuild@0.25.0) + esbuild: 0.27.2 + esbuild-register: 3.6.0(esbuild@0.27.2) recast: 0.23.11 semver: 7.7.3 ws: 8.18.3 @@ -17300,9 +17042,9 @@ snapshots: strip-json-comments@5.0.3: {} - style-loader@3.3.4(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + style-loader@3.3.4(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) style-to-js@1.1.21: dependencies: @@ -17405,16 +17147,16 @@ snapshots: readable-stream: 3.6.2 optional: true - terser-webpack-plugin@5.3.15(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + terser-webpack-plugin@5.3.15(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: '@jridgewell/trace-mapping': 0.3.31 jest-worker: 27.5.1 schema-utils: 4.3.3 serialize-javascript: 6.0.2 terser: 5.44.1 - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) optionalDependencies: - esbuild: 0.25.0 + esbuild: 0.27.2 uglify-js: 3.19.3 terser@5.44.1: @@ -17542,7 +17284,7 @@ snapshots: tsx@4.21.0: dependencies: - esbuild: 0.25.12 + esbuild: 0.27.2 get-tsconfig: 4.13.0 optionalDependencies: fsevents: 2.3.3 @@ -17750,7 +17492,7 @@ snapshots: vite@6.4.1(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2): dependencies: - esbuild: 0.25.12 + esbuild: 0.27.2 fdir: 6.5.0(picomatch@4.0.3) picomatch: 4.0.3 postcss: 8.5.6 @@ -17767,7 +17509,7 @@ snapshots: vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2): dependencies: - esbuild: 0.25.12 + esbuild: 0.27.2 fdir: 6.5.0(picomatch@4.0.3) picomatch: 4.0.3 postcss: 8.5.6 @@ -17892,7 +17634,7 @@ snapshots: - bufferutil - utf-8-validate - webpack-dev-middleware@6.1.3(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)): + webpack-dev-middleware@6.1.3(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: colorette: 2.0.20 memfs: 3.5.3 @@ -17900,7 +17642,7 @@ snapshots: range-parser: 1.2.1 schema-utils: 4.3.3 optionalDependencies: - webpack: 5.103.0(esbuild@0.25.0)(uglify-js@3.19.3) + webpack: 5.103.0(esbuild@0.27.2)(uglify-js@3.19.3) webpack-hot-middleware@2.26.1: dependencies: @@ -17912,7 +17654,7 @@ snapshots: webpack-virtual-modules@0.6.2: {} - webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3): + webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3): dependencies: '@types/eslint-scope': 3.7.7 '@types/estree': 1.0.8 @@ -17936,7 +17678,7 @@ snapshots: neo-async: 2.6.2 schema-utils: 4.3.3 tapable: 2.3.0 - terser-webpack-plugin: 5.3.15(esbuild@0.25.0)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.25.0)(uglify-js@3.19.3)) + terser-webpack-plugin: 5.3.15(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.103.0(esbuild@0.27.2)(uglify-js@3.19.3)) watchpack: 2.4.4 webpack-sources: 3.3.3 transitivePeerDependencies: From 51ea87ab85be8ab7d7c26558f4bb3cfa6fd9928d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Mon, 12 Jan 2026 15:57:40 +0800 Subject: [PATCH 03/29] feat: clear free plan workflow run logs (#29494) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/.env.example | 1 + api/commands.py | 57 +++ api/configs/feature/__init__.py | 4 + api/core/workflow/enums.py | 4 + api/extensions/ext_celery.py | 7 + api/extensions/ext_commands.py | 2 + .../versions/2026_01_09_1630-905527cc8fd3_.py | 30 ++ api/models/workflow.py | 1 + .../api_workflow_run_repository.py | 43 ++- ..._api_workflow_node_execution_repository.py | 147 +++++++- .../sqlalchemy_api_workflow_run_repository.py | 179 +++++++++- ...alchemy_workflow_trigger_log_repository.py | 38 +- .../workflow_trigger_log_repository.py | 12 + api/schedule/clean_workflow_runs_task.py | 43 +++ api/services/retention/__init__.py | 0 .../retention/workflow_run/__init__.py | 0 ...ear_free_plan_expired_workflow_run_logs.py | 301 ++++++++++++++++ ..._sqlalchemy_api_workflow_run_repository.py | 92 +++++ ...alchemy_workflow_trigger_log_repository.py | 31 ++ ...ear_free_plan_expired_workflow_run_logs.py | 327 ++++++++++++++++++ docker/.env.example | 1 + docker/docker-compose.yaml | 1 + 22 files changed, 1312 insertions(+), 9 deletions(-) create mode 100644 api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py create mode 100644 api/schedule/clean_workflow_runs_task.py create mode 100644 api/services/retention/__init__.py create mode 100644 api/services/retention/workflow_run/__init__.py create mode 100644 api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py create mode 100644 api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py create mode 100644 api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py diff --git a/api/.env.example b/api/.env.example index 44d770ed70..8099c4a42a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -589,6 +589,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/api/commands.py b/api/commands.py index 7ebf5b4874..e24b1826ee 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,4 +1,5 @@ import base64 +import datetime import json import logging import secrets @@ -45,6 +46,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) @@ -852,6 +854,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.") +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview cleanup results without deleting any workflow run data.", +) +def clean_workflow_runs( + days: int, + batch_size: int, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + dry_run: bool, +): + """ + Clean workflow runs and related workflow data for free tenants. + """ + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) + + WorkflowRunCleanup( + days=days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + ).run() + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Workflow run cleanup completed. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 6a04171d2d..cf855b1cc0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1101,6 +1101,10 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable clean messages task", default=False, ) + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field( + description="Enable scheduled workflow run cleanup task", + default=False, + ) ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( description="Enable mail clean document notify task", default=False, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index c08b62a253..bb3b13e8c6 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum): def is_ended(self) -> bool: return self in _END_STATE + @classmethod + def ended_values(cls) -> list[str]: + return [status.value for status in _END_STATE] + _END_STATE = frozenset( [ diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 2fbab001d0..08cf96c1c1 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -163,6 +163,13 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", "schedule": crontab(minute="0", hour="2"), } + if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK: + # for saas only + imports.append("schedule.clean_workflow_runs_task") + beat_schedule["clean_workflow_runs_task"] = { + "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task", + "schedule": crontab(minute="0", hour="0"), + } if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: imports.append("schedule.workflow_schedule_task") beat_schedule["workflow_schedule_task"] = { diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index daa3756dba..c32130d377 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, @@ -56,6 +57,7 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + clean_workflow_runs, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py new file mode 100644 index 0000000000..7e0cc8ec9d --- /dev/null +++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py @@ -0,0 +1,30 @@ +"""add workflow_run_created_at_id_idx + +Revision ID: 905527cc8fd3 +Revises: 7df29de0f6be +Create Date: 2025-01-09 16:30:02.462084 + +""" +from alembic import op +import models as models + +# revision identifiers, used by Alembic. +revision = '905527cc8fd3' +down_revision = '7df29de0f6be' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_created_at_id_idx') + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index a18939523b..072c6100b5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -597,6 +597,7 @@ class WorkflowRun(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.Index("workflow_run_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index fd547c78ba..1a2b84fdf9 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -34,11 +34,14 @@ Example: ``` """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from sqlalchemy.orm import Session + from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -253,6 +256,44 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + """ + ... + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Delete workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons). + """ + ... + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Count workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons) without deleting data. + """ + ... + def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 7e2173acdd..2de3a15d65 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,13 +7,18 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import TypedDict, cast -from sqlalchemy import asc, delete, desc, select +from sqlalchemy import asc, delete, desc, func, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from models.workflow import WorkflowNodeExecutionModel +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import ( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowNodeExecutionTriggeredFrom, +) from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -44,6 +49,26 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut """ self._session_maker = session_maker + @staticmethod + def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str: + """ + Map workflow run triggered_from values to workflow node execution triggered_from values. + """ + if triggered_from in { + WorkflowRunTriggeredFrom.APP_RUN.value, + WorkflowRunTriggeredFrom.DEBUGGING.value, + WorkflowRunTriggeredFrom.SCHEDULE.value, + WorkflowRunTriggeredFrom.PLUGIN.value, + WorkflowRunTriggeredFrom.WEBHOOK.value, + }: + return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + if triggered_from in { + WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, + WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, + }: + return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value + return "" + def get_node_last_execution( self, tenant_id: str, @@ -290,3 +315,119 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut result = cast(CursorResult, session.execute(stmt)) session.commit() return result.rowcount + + class RunContext(TypedDict): + run_id: str + tenant_id: str + app_id: str + workflow_id: str + triggered_from: str + + @staticmethod + def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]: + """ + Delete node executions (and offloads) for the given workflow runs using indexed columns. + + Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id) + by filtering on those columns with tuple IN. + """ + if not runs: + return 0, 0 + + tuple_values = [ + ( + run["tenant_id"], + run["app_id"], + run["workflow_id"], + DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from( + run["triggered_from"] + ), + run["run_id"], + ) + for run in runs + ] + + node_execution_ids = session.scalars( + select(WorkflowNodeExecutionModel.id).where( + tuple_( + WorkflowNodeExecutionModel.tenant_id, + WorkflowNodeExecutionModel.app_id, + WorkflowNodeExecutionModel.workflow_id, + WorkflowNodeExecutionModel.triggered_from, + WorkflowNodeExecutionModel.workflow_run_id, + ).in_(tuple_values) + ) + ).all() + + if not node_execution_ids: + return 0, 0 + + offloads_deleted = ( + cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + ), + ).rowcount + or 0 + ) + + node_executions_deleted = ( + cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids)) + ), + ).rowcount + or 0 + ) + + return node_executions_deleted, offloads_deleted + + @staticmethod + def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]: + """ + Count node executions (and offloads) for the given workflow runs using indexed columns. + """ + if not runs: + return 0, 0 + + tuple_values = [ + ( + run["tenant_id"], + run["app_id"], + run["workflow_id"], + DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from( + run["triggered_from"] + ), + run["run_id"], + ) + for run in runs + ] + tuple_filter = tuple_( + WorkflowNodeExecutionModel.tenant_id, + WorkflowNodeExecutionModel.app_id, + WorkflowNodeExecutionModel.workflow_id, + WorkflowNodeExecutionModel.triggered_from, + WorkflowNodeExecutionModel.workflow_run_id, + ).in_(tuple_values) + + node_executions_count = ( + session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0 + ) + offloads_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowNodeExecutionOffload) + .join( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id, + ) + .where(tuple_filter) + ) + or 0 + ) + + return int(node_executions_count), int(offloads_count) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b172c6a3ac..9d2d06e99f 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -21,7 +21,7 @@ Implementation Notes: import logging import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from decimal import Decimal from typing import Any, cast @@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowType from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -40,8 +40,14 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun +from models.workflow import ( + WorkflowAppLog, + WorkflowPauseReason, + WorkflowRun, +) +from models.workflow import ( + WorkflowPause as WorkflowPauseModel, +) from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -314,6 +320,171 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + + Query scope: + - created_at in [start_from, end_before) + - type in run_types (when provided) + - status is an ended state + - optional tenant_id filter and cursor (last_seen) for pagination + """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.created_at < end_before, + WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()), + ) + .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc()) + .limit(batch_size) + ) + if run_types is not None: + if not run_types: + return [] + stmt = stmt.where(WorkflowRun.type.in_(run_types)) + + if start_from: + stmt = stmt.where(WorkflowRun.created_at >= start_from) + + if tenant_ids: + stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids)) + + if last_seen: + stmt = stmt.where( + or_( + WorkflowRun.created_at > last_seen[0], + and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]), + ) + ) + + return session.scalars(stmt).all() + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if delete_node_executions: + node_executions_deleted, offloads_deleted = delete_node_executions(session, runs) + else: + node_executions_deleted, offloads_deleted = 0, 0 + + app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) + app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 + + pause_ids = session.scalars( + select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + ).all() + pause_reasons_deleted = 0 + pauses_deleted = 0 + + if pause_ids: + pause_reasons_result = session.execute( + delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 + pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 + + trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 + + runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))) + runs_deleted = cast(CursorResult, runs_result).rowcount or 0 + + session.commit() + + return { + "runs": runs_deleted, + "node_executions": node_executions_deleted, + "offloads": offloads_deleted, + "app_logs": app_logs_deleted, + "trigger_logs": trigger_logs_deleted, + "pauses": pauses_deleted, + "pause_reasons": pause_reasons_deleted, + } + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if count_node_executions: + node_executions_count, offloads_count = count_node_executions(session, runs) + else: + node_executions_count, offloads_count = 0, 0 + + app_logs_count = ( + session.scalar( + select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)) + ) + or 0 + ) + + pause_ids = session.scalars( + select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + ).all() + pauses_count = len(pause_ids) + pause_reasons_count = 0 + if pause_ids: + pause_reasons_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowPauseReason) + .where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + or 0 + ) + + trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0 + + return { + "runs": len(runs), + "node_executions": node_executions_count, + "offloads": offloads_count, + "app_logs": int(app_logs_count), + "trigger_logs": trigger_logs_count, + "pauses": pauses_count, + "pause_reasons": int(pause_reasons_count), + } + def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 0d67e286b0..ebd3745d18 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository. from collections.abc import Sequence from datetime import UTC, datetime, timedelta +from typing import cast -from sqlalchemy import and_, select +from sqlalchemy import and_, delete, func, select +from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from models.enums import WorkflowTriggerStatus @@ -84,3 +86,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): ) return list(self.session.scalars(query).all()) + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows deleted. + """ + if not run_ids: + return 0 + + result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))) + return cast(CursorResult, result).rowcount or 0 + + def count_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Count trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows matched. + """ + if not run_ids: + return 0 + + count = self.session.scalar( + select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)) + ) + return int(count or 0) diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index 138b8779ac..b0009e398d 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol): A sequence of recent WorkflowTriggerLog instances """ ... + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs for workflow run IDs. + + Args: + run_ids: Workflow run IDs to delete + + Returns: + Number of rows deleted + """ + ... diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py new file mode 100644 index 0000000000..9f5bf8e150 --- /dev/null +++ b/api/schedule/clean_workflow_runs_task.py @@ -0,0 +1,43 @@ +from datetime import UTC, datetime + +import click + +import app +from configs import dify_config +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +@app.celery.task(queue="retention") +def clean_workflow_runs_task() -> None: + """ + Scheduled cleanup for workflow runs and related records (sandbox tenants only). + """ + click.echo( + click.style( + ( + "Scheduled workflow run cleanup starting: " + f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, " + f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}" + ), + fg="green", + ) + ) + + start_time = datetime.now(UTC) + + WorkflowRunCleanup( + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + start_from=None, + end_before=None, + ).run() + + end_time = datetime.now(UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..2213169510 --- /dev/null +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,301 @@ +import datetime +import logging +from collections.abc import Iterable, Sequence + +import click +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +class WorkflowRunCleanup: + def __init__( + self, + days: int, + batch_size: int, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workflow_run_repo: APIWorkflowRunRepository | None = None, + dry_run: bool = False, + ): + if (start_from is None) ^ (end_before is None): + raise ValueError("start_from and end_before must be both set or both omitted.") + + computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days) + self.window_start = start_from + self.window_end = end_before or computed_cutoff + + if self.window_start and self.window_end <= self.window_start: + raise ValueError("end_before must be greater than start_from.") + + if batch_size <= 0: + raise ValueError("batch_size must be greater than 0.") + + self.batch_size = batch_size + self._cleanup_whitelist: set[str] | None = None + self.dry_run = dry_run + self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD + self.workflow_run_repo: APIWorkflowRunRepository + if workflow_run_repo: + self.workflow_run_repo = workflow_run_repo + else: + # Lazy import to avoid circular dependencies during module import + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + def run(self) -> None: + click.echo( + click.style( + f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs " + f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}" + f"{self.window_end.isoformat()} (batch={self.batch_size})", + fg="white", + ) + ) + if self.dry_run: + click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow")) + + total_runs_deleted = 0 + total_runs_targeted = 0 + related_totals = self._empty_related_counts() if self.dry_run else None + batch_index = 0 + last_seen: tuple[datetime.datetime, str] | None = None + + while True: + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( + start_from=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, + ) + if not run_rows: + break + + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) + tenant_ids = {row.tenant_id for row in run_rows} + free_tenants = self._filter_free_tenants(tenant_ids) + free_runs = [row for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_runs) + + if not free_runs: + click.echo( + click.style( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)", + fg="yellow", + ) + ) + continue + + total_runs_targeted += len(free_runs) + + if self.dry_run: + batch_counts = self.workflow_run_repo.count_runs_with_related( + free_runs, + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + if related_totals is not None: + for key in related_totals: + related_totals[key] += batch_counts.get(key, 0) + sample_ids = ", ".join(run.id for run in free_runs[:5]) + click.echo( + click.style( + f"[batch #{batch_index}] would delete {len(free_runs)} runs " + f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", + fg="yellow", + ) + ) + continue + + try: + counts = self.workflow_run_repo.delete_runs_with_related( + free_runs, + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] + click.echo( + click.style( + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", + ) + ) + + if self.dry_run: + if self.window_start: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + if related_totals is not None: + summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}" + summary_color = "yellow" + else: + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" + ) + summary_color = "white" + + click.echo(click.style(summary_message, fg=summary_color)) + + def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: + tenant_id_list = list(tenant_ids) + + if not dify_config.BILLING_ENABLED: + return set(tenant_id_list) + + if not tenant_id_list: + return set() + + cleanup_whitelist = self._get_cleanup_whitelist() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list) + except Exception: + bulk_info = {} + logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list) + + eligible_free_tenants: set[str] = set() + for tenant_id in tenant_id_list: + if tenant_id in cleanup_whitelist: + continue + + info = bulk_info.get(tenant_id) + if info is None: + logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id) + continue + + if info.get("plan") != CloudPlan.SANDBOX: + continue + + if self._is_within_grace_period(tenant_id, info): + continue + + eligible_free_tenants.add(tenant_id) + + return eligible_free_tenants + + def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None: + if expiration_value < 0: + return None + + try: + return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC) + except (OverflowError, OSError, ValueError): + logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id) + return None + + def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool: + if self.free_plan_grace_period_days <= 0: + return False + + expiration_value = info.get("expiration_date", -1) + expiration_at = self._expiration_datetime(tenant_id, expiration_value) + if expiration_at is None: + return False + + grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days) + return datetime.datetime.now(datetime.UTC) < grace_deadline + + def _get_cleanup_whitelist(self) -> set[str]: + if self._cleanup_whitelist is not None: + return self._cleanup_whitelist + + if not dify_config.BILLING_ENABLED: + self._cleanup_whitelist = set() + return self._cleanup_whitelist + + try: + whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist() + except Exception: + logger.exception("Failed to fetch cleanup whitelist from billing service") + whitelist_ids = [] + + self._cleanup_whitelist = set(whitelist_ids) + return self._cleanup_whitelist + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.count_by_run_ids(run_ids) + + @staticmethod + def _build_run_contexts( + runs: Sequence[WorkflowRun], + ) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]: + return [ + { + "run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "triggered_from": run.triggered_from, + } + for run in runs + ] + + @staticmethod + def _empty_related_counts() -> dict[str, int]: + return { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + @staticmethod + def _format_related_counts(counts: dict[str, int]) -> str: + return ( + f"node_executions {counts['node_executions']}, " + f"offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, " + f"trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, " + f"pause_reasons {counts['pause_reasons']}" + ) + + def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_contexts = self._build_run_contexts(runs) + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_contexts = self._build_run_contexts(runs) + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 0c34676252..d443c4c9a5 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker from core.workflow.enums import WorkflowExecutionStatus @@ -104,6 +105,42 @@ class TestDifyAPISQLAlchemyWorkflowRunRepository: return pause +class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_get_runs_batch_by_time_range_filters_terminal_statuses( + self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock + ): + scalar_result = Mock() + scalar_result.all.return_value = [] + mock_session.scalars.return_value = scalar_result + + repository.get_runs_batch_by_time_range( + start_from=None, + end_before=datetime(2024, 1, 1), + last_seen=None, + batch_size=50, + ) + + stmt = mock_session.scalars.call_args[0][0] + compiled_sql = str( + stmt.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "workflow_runs.status" in compiled_sql + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.STOPPED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + ): + assert f"'{status.value}'" in compiled_sql + + assert "'running'" not in compiled_sql + assert "'paused'" not in compiled_sql + + class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test create_workflow_pause method.""" @@ -181,6 +218,61 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ) +class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + node_ids_result = Mock() + node_ids_result.all.return_value = [] + pause_ids_result = Mock() + pause_ids_result.all.return_value = [] + mock_session.scalars.side_effect = [node_ids_result, pause_ids_result] + + # app_logs delete, runs delete + mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)] + + fake_trigger_repo = Mock() + fake_trigger_repo.delete_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.delete_runs_with_related( + [run], + delete_node_executions=lambda session, runs: (2, 1), + delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids), + ) + + fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["runs"] == 1 + + +class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + pause_ids_result = Mock() + pause_ids_result.all.return_value = ["pause-1", "pause-2"] + mock_session.scalars.return_value = pause_ids_result + mock_session.scalar.side_effect = [5, 2] + + fake_trigger_repo = Mock() + fake_trigger_repo.count_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.count_runs_with_related( + [run], + count_node_executions=lambda session, runs: (2, 1), + count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids), + ) + + fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["app_logs"] == 5 + assert counts["pauses"] == 2 + assert counts["pause_reasons"] == 2 + assert counts["runs"] == 1 + + class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test resume_workflow_pause method.""" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py new file mode 100644 index 0000000000..d409618211 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py @@ -0,0 +1,31 @@ +from unittest.mock import Mock + +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Session + +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +def test_delete_by_run_ids_executes_delete(): + session = Mock(spec=Session) + session.execute.return_value = Mock(rowcount=2) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids(["run-1", "run-2"]) + + stmt = session.execute.call_args[0][0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) + assert "workflow_trigger_logs" in compiled_sql + assert "'run-1'" in compiled_sql + assert "'run-2'" in compiled_sql + assert deleted == 2 + + +def test_delete_by_run_ids_empty_short_circuits(): + session = Mock(spec=Session) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids([]) + + session.execute.assert_not_called() + assert deleted == 0 diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..8c80e2b4ad --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,327 @@ +import datetime +from typing import Any + +import pytest + +from services.billing_service import SubscriptionPlan +from services.retention.workflow_run import clear_free_plan_expired_workflow_run_logs as cleanup_module +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +class FakeRun: + def __init__( + self, + run_id: str, + tenant_id: str, + created_at: datetime.datetime, + app_id: str = "app-1", + workflow_id: str = "wf-1", + triggered_from: str = "workflow-run", + ) -> None: + self.id = run_id + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.triggered_from = triggered_from + self.created_at = created_at + + +class FakeRepo: + def __init__( + self, + batches: list[list[FakeRun]], + delete_result: dict[str, int] | None = None, + count_result: dict[str, int] | None = None, + ) -> None: + self.batches = batches + self.call_idx = 0 + self.deleted: list[list[str]] = [] + self.counted: list[list[str]] = [] + self.delete_result = delete_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + self.count_result = count_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def get_runs_batch_by_time_range( + self, + start_from: datetime.datetime | None, + end_before: datetime.datetime, + last_seen: tuple[datetime.datetime, str] | None, + batch_size: int, + ) -> list[FakeRun]: + if self.call_idx >= len(self.batches): + return [] + batch = self.batches[self.call_idx] + self.call_idx += 1 + return batch + + def delete_runs_with_related( + self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None + ) -> dict[str, int]: + self.deleted.append([run.id for run in runs]) + result = self.delete_result.copy() + result["runs"] = len(runs) + return result + + def count_runs_with_related( + self, runs: list[FakeRun], count_node_executions=None, count_trigger_logs=None + ) -> dict[str, int]: + self.counted.append([run.id for run in runs]) + result = self.count_result.copy() + result["runs"] = len(runs) + return result + + +def plan_info(plan: str, expiration: int) -> SubscriptionPlan: + return SubscriptionPlan(plan=plan, expiration_date=expiration) + + +def create_cleanup( + monkeypatch: pytest.MonkeyPatch, + repo: FakeRepo, + *, + grace_period_days: int = 0, + whitelist: set[str] | None = None, + **kwargs: Any, +) -> WorkflowRunCleanup: + monkeypatch.setattr( + cleanup_module.dify_config, + "SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD", + grace_period_days, + ) + monkeypatch.setattr( + cleanup_module.WorkflowRunCleanup, + "_get_cleanup_whitelist", + lambda self: whitelist or set(), + ) + return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs) + + +def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + def fail_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + raise RuntimeError("should not call") + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fail_bulk)) + + tenants = {"t1", "t2"} + free = cleanup._filter_free_tenants(tenants) + + assert free == tenants + + +def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"}) + + assert free == {"t_free", "t_missing"} + + +def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + now = datetime.datetime.now(datetime.UTC) + within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp()) + outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp()) + + def fake_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + return { + "recently_downgraded": plan_info("sandbox", within_grace_ts), + "long_sandbox": plan_info("sandbox", outside_grace_ts), + } + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fake_bulk)) + + free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"}) + + assert free == {"long_sandbox"} + + +def test_filter_free_tenants_skips_cleanup_whitelist(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup( + monkeypatch, + repo=FakeRepo([]), + days=30, + batch_size=10, + whitelist={"tenant_whitelist"}, + ) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + tenants = {"tenant_whitelist", "tenant_regular"} + free = cleanup._filter_free_tenants(tenants) + + assert free == {"tenant_regular"} + + +def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))), + ) + + free = cleanup._filter_free_tenants({"t1", "t2"}) + + assert free == set() + + +def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[ + [ + FakeRun("run-free", "t_free", cutoff), + FakeRun("run-paid", "t_paid", cutoff), + ] + ] + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + cleanup.run() + + assert repo.deleted == [["run-free"]] + + +def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]]) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}), + ) + + cleanup.run() + + assert repo.deleted == [] + + +def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + cleanup.run() + + +def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[[FakeRun("run-free", "t_free", cutoff)]], + count_result={ + "runs": 0, + "node_executions": 2, + "offloads": 1, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + }, + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10, dry_run=True) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + cleanup.run() + + assert repo.deleted == [] + assert repo.counted == [["run-free"]] + captured = capsys.readouterr().out + assert "Dry run mode enabled" in captured + assert "would delete 1 runs" in captured + assert "related records" in captured + assert "node_executions 2" in captured + assert "offloads 1" in captured + assert "app_logs 3" in captured + assert "trigger_logs 4" in captured + assert "pauses 5" in captured + assert "pause_reasons 6" in captured + + +def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 5, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 6, 1, 0, 0, 0) + cleanup = create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) + + assert cleanup.window_start == start_from + assert cleanup.window_end == end_before + + +def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None: + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=datetime.datetime.now(), end_before=None + ) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=None, end_before=datetime.datetime.now() + ) + + +def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 6, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 5, 1, 0, 0, 0) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) diff --git a/docker/.env.example b/docker/.env.example index 09ee1060e2..e7cb8711ce 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1478,6 +1478,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 712de84c62..041f60aaa2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -662,6 +662,7 @@ x-shared-env: &shared-api-worker-env ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false} ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false} ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false} + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: ${ENABLE_WORKFLOW_RUN_CLEANUP_TASK:-false} ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true} From b63dfbf654ed2191aa94aa4bcf8d09c09cc90759 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 12 Jan 2026 16:23:18 +0800 Subject: [PATCH 04/29] fix(api): defer streaming response until referenced variables are updated (#30832) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../nodes/variable_assigner/v1/node.py | 9 + ...ng_conversation_variables_v1_overwrite.yml | 158 ++++++++++++++++++ .../test_streaming_conversation_variables.py | 30 ++++ 3 files changed, 197 insertions(+) create mode 100644 api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index d2ea7d94ea..ac2870aa65 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]): graph_runtime_state=graph_runtime_state, ) + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: + """ + Check if this Variable Assigner node blocks the output of specific variables. + + Returns True if this node updates any of the requested conversation variables. + """ + assigned_selector = tuple(self.node_data.assigned_variable_selector) + return assigned_selector in variable_selectors + @classmethod def version(cls) -> str: return "1" diff --git a/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml b/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml new file mode 100644 index 0000000000..5b2a6260e9 --- /dev/null +++ b/api/tests/fixtures/workflow/test_streaming_conversation_variables_v1_overwrite.yml @@ -0,0 +1,158 @@ +app: + description: Validate v1 Variable Assigner blocks streaming until conversation variable is updated. + icon: 🤖 + icon_background: '#FFEAD5' + mode: advanced-chat + name: test_streaming_conversation_variables_v1_overwrite + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.5.0 +workflow: + conversation_variables: + - description: '' + id: 6ddf2d7f-3d1b-4bb0-9a5e-9b0c87c7b5e6 + name: conv_var + selector: + - conversation + - conv_var + value: default + value_type: string + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: assigner + id: start-source-assigner-target + source: start + sourceHandle: source + target: assigner + targetHandle: target + type: custom + zIndex: 0 + - data: + isInLoop: false + sourceType: assigner + targetType: answer + id: assigner-source-answer-target + source: assigner + sourceHandle: source + target: answer + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + desc: '' + selected: false + title: Start + type: start + variables: [] + height: 54 + id: start + position: + x: 30 + y: 253 + positionAbsolute: + x: 30 + y: 253 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + answer: 'Current Value Of `conv_var` is:{{#conversation.conv_var#}}' + desc: '' + selected: false + title: Answer + type: answer + variables: [] + height: 106 + id: answer + position: + x: 638 + y: 253 + positionAbsolute: + x: 638 + y: 253 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 244 + - data: + assigned_variable_selector: + - conversation + - conv_var + desc: '' + input_variable_selector: + - sys + - query + selected: false + title: Variable Assigner + type: assigner + write_mode: over-write + height: 84 + id: assigner + position: + x: 334 + y: 253 + positionAbsolute: + x: 334 + y: 253 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 244 + viewport: + x: 0 + y: 0 + zoom: 0.7 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 1f4c063bf0..99157a7c3e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -45,3 +45,33 @@ def test_streaming_conversation_variables(): runner = TableTestRunner() result = runner.run_test_case(case) assert result.success, f"Test failed: {result.error}" + + +def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment(): + fixture_name = "test_streaming_conversation_variables_v1_overwrite" + input_query = "overwrite-value" + + case = WorkflowTestCase( + fixture_path=fixture_name, + use_auto_mock=False, + mock_config=MockConfigBuilder().build(), + query=input_query, + inputs={}, + expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"}, + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + assert result.success, f"Test failed: {result.error}" + + events = result.events + conv_var_chunk_events = [ + event + for event in events + if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var") + ] + + assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted" + assert all(event.chunk == input_query for event in conv_var_chunk_events), ( + "Expected streamed conversation variable value to match the input query" + ) From 837237aa6d00582aee2621f96974df306239d779 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 13 Jan 2026 10:11:18 +0800 Subject: [PATCH 05/29] fix: use node factory for single-step workflow nodes (#30859) --- api/core/workflow/workflow_entry.py | 13 ++--- .../core/workflow/test_workflow_entry.py | 56 +++++++++++++++++++ 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ddf545bb34..fd3fc02f62 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node +from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -136,13 +137,11 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - node_config = workflow.get_node_config_by_id(node_id) + node_config = dict(workflow.get_node_config_by_id(node_id)) node_config_data = node_config.get("data", {}) - # Get node class + # Get node type node_type = NodeType(node_config_data.get("type")) - node_version = node_config_data.get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -158,12 +157,12 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node = node_cls( - id=str(uuid.uuid4()), - config=node_config, + node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + node = node_factory.create_node(node_config) + node_cls = type(node) try: # variable selector to variable mapping diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 68d6c109e8..b38e070ffc 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -2,13 +2,17 @@ from types import SimpleNamespace import pytest +from configs import dify_config from core.file.enums import FileType from core.file.models import File, FileTransferMethod +from core.helper.code_executor.code_executor import CodeLanguage from core.variables.variables import StringVariable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry @@ -96,6 +100,58 @@ class TestWorkflowEntry: assert output_var is not None assert output_var.value == "system_user" + def test_single_step_run_injects_code_limits(self): + """Ensure single-step CodeNode execution configures limits.""" + # Arrange + node_id = "code_node" + node_data = { + "type": "code", + "title": "Code", + "desc": None, + "variables": [], + "code_language": CodeLanguage.PYTHON3, + "code": "def main():\n return {}", + "outputs": {}, + } + node_config = {"id": node_id, "data": node_data} + + class StubWorkflow: + def __init__(self): + self.tenant_id = "tenant" + self.app_id = "app" + self.id = "workflow" + self.graph_dict = {"nodes": [node_config], "edges": []} + + def get_node_config_by_id(self, target_id: str): + assert target_id == node_id + return node_config + + workflow = StubWorkflow() + variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={}) + expected_limits = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + + # Act + node, _ = WorkflowEntry.single_step_run( + workflow=workflow, + node_id=node_id, + user_id="user", + user_inputs={}, + variable_pool=variable_pool, + ) + + # Assert + assert isinstance(node, CodeNode) + assert node._limits == expected_limits + def test_mapping_user_inputs_to_variable_pool_with_env_variables(self): """Test mapping environment variables from user inputs to variable pool.""" # Initialize variable pool with environment variables From 450578d4c0db3b24ee0a9bc4904afd6b85ee1008 Mon Sep 17 00:00:00 2001 From: heyszt <270985384@qq.com> Date: Tue, 13 Jan 2026 10:12:00 +0800 Subject: [PATCH 06/29] feat(ops): set root span kind for AliyunTrace to enable service-level metrics aggregation (#30728) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/ops/aliyun_trace/aliyun_trace.py | 4 ++++ api/core/ops/aliyun_trace/data_exporter/traceclient.py | 2 +- api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py | 3 ++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index d6bd4d2015..cf6659150f 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,7 @@ import logging from collections.abc import Sequence +from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( @@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER, ) self.trace_client.add_span(message_span) @@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER, ) self.trace_client.add_span(message_span) @@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance): ), status=status, links=trace_metadata.links, + span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL, ) self.trace_client.add_span(workflow_span) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index d3324f8f82..7624586367 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -166,7 +166,7 @@ class SpanBuilder: attributes=span_data.attributes, events=span_data.events, links=span_data.links, - kind=trace_api.SpanKind.INTERNAL, + kind=span_data.span_kind, status=span_data.status, start_time=span_data.start_time, end_time=span_data.end_time, diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index 20ff2d0875..9078031490 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -4,7 +4,7 @@ from typing import Any from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import BaseModel, Field @@ -34,3 +34,4 @@ class SpanData(BaseModel): status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") start_time: int | None = Field(..., description="The start time of the span in nanoseconds.") end_time: int | None = Field(..., description="The end time of the span in nanoseconds.") + span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.") From a012c87445bb0bdd7cb2f5f6069e7b595340a958 Mon Sep 17 00:00:00 2001 From: hsiong <37357447+hsiong@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:12:51 +0800 Subject: [PATCH 07/29] fix: entrypoint.sh overrides NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS when TEXT_GENERATION_TIMEOUT_MS is unset (#30864) (#30865) --- web/.env.example | 2 ++ 1 file changed, 2 insertions(+) diff --git a/web/.env.example b/web/.env.example index f2f25454cb..df4e725c51 100644 --- a/web/.env.example +++ b/web/.env.example @@ -31,6 +31,8 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false # The timeout for the text generation in millisecond NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 +# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only) +TEXT_GENERATION_TIMEOUT_MS=60000 # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP NEXT_PUBLIC_CSP_WHITELIST= From 9ee71902c1527a2d2654a3aa2fd1e5e1f4173450 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 13 Jan 2026 11:51:15 +0800 Subject: [PATCH 08/29] fix: fix formatNumber accuracy (#30877) --- web/utils/format.spec.ts | 22 ++++++++++++++++++++++ web/utils/format.ts | 30 +++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/web/utils/format.spec.ts b/web/utils/format.spec.ts index 0fde0ccbe8..3a1709dbdc 100644 --- a/web/utils/format.spec.ts +++ b/web/utils/format.spec.ts @@ -19,6 +19,28 @@ describe('formatNumber', () => { it('should correctly handle empty input', () => { expect(formatNumber('')).toBe('') }) + it('should format very small numbers without scientific notation', () => { + expect(formatNumber(0.0000008)).toBe('0.0000008') + expect(formatNumber(0.0000001)).toBe('0.0000001') + expect(formatNumber(0.000001)).toBe('0.000001') + expect(formatNumber(0.00001)).toBe('0.00001') + }) + it('should format negative small numbers without scientific notation', () => { + expect(formatNumber(-0.0000008)).toBe('-0.0000008') + expect(formatNumber(-0.0000001)).toBe('-0.0000001') + }) + it('should handle small numbers from string input', () => { + expect(formatNumber('0.0000008')).toBe('0.0000008') + expect(formatNumber('8E-7')).toBe('0.0000008') + expect(formatNumber('1e-7')).toBe('0.0000001') + }) + it('should handle small numbers with multi-digit mantissa in scientific notation', () => { + expect(formatNumber(1.23e-7)).toBe('0.000000123') + expect(formatNumber(1.234e-7)).toBe('0.0000001234') + expect(formatNumber(12.34e-7)).toBe('0.000001234') + expect(formatNumber(0.0001234)).toBe('0.0001234') + expect(formatNumber('1.23e-7')).toBe('0.000000123') + }) }) describe('formatFileSize', () => { it('should return the input if it is falsy', () => { diff --git a/web/utils/format.ts b/web/utils/format.ts index d087d690a2..0c81b339a3 100644 --- a/web/utils/format.ts +++ b/web/utils/format.ts @@ -26,11 +26,39 @@ import 'dayjs/locale/zh-tw' * Formats a number with comma separators. * @example formatNumber(1234567) will return '1,234,567' * @example formatNumber(1234567.89) will return '1,234,567.89' + * @example formatNumber(0.0000008) will return '0.0000008' */ export const formatNumber = (num: number | string) => { if (!num) return num - const parts = num.toString().split('.') + const n = typeof num === 'string' ? Number(num) : num + + let numStr: string + + // Force fixed decimal for small numbers to avoid scientific notation + if (Math.abs(n) < 0.001 && n !== 0) { + const str = n.toString() + const match = str.match(/e-(\d+)$/) + let precision: number + if (match) { + // Scientific notation: precision is exponent + decimal digits in mantissa + const exponent = Number.parseInt(match[1], 10) + const mantissa = str.split('e')[0] + const mantissaDecimalPart = mantissa.split('.')[1] + precision = exponent + (mantissaDecimalPart?.length || 0) + } + else { + // Decimal notation: count decimal places + const decimalPart = str.split('.')[1] + precision = decimalPart?.length || 0 + } + numStr = n.toFixed(precision) + } + else { + numStr = n.toString() + } + + const parts = numStr.split('.') parts[0] = parts[0].replace(/\B(?=(\d{3})+(?!\d))/g, ',') return parts.join('.') } From 8f43629cd84f4cebf47ba84d6d018825853483ec Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Tue, 13 Jan 2026 12:26:50 +0800 Subject: [PATCH 09/29] fix(amplitude): update sessionReplaySampleRate default value to 0.5 (#30880) Co-authored-by: CodingOnStar --- web/app/components/base/amplitude/AmplitudeProvider.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx index 87ef516835..0f083a4a7d 100644 --- a/web/app/components/base/amplitude/AmplitudeProvider.tsx +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => { } const AmplitudeProvider: FC = ({ - sessionReplaySampleRate = 1, + sessionReplaySampleRate = 0.5, }) => { useEffect(() => { // Only enable in Saas edition with valid API key From 9be863fefa03456508f7abd605d5cb78ca5a5bb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Tue, 13 Jan 2026 12:46:33 +0800 Subject: [PATCH 10/29] fix: missing content if assistant message with tool_calls (#30083) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/agent/fc_agent_runner.py | 4 +--- api/core/model_runtime/entities/message_entities.py | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 68d14ad027..7c5c9136a7 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ), ) - assistant_message = AssistantPromptMessage(content="", tool_calls=[]) + assistant_message = AssistantPromptMessage(content=response, tool_calls=[]) if tool_calls: assistant_message.tool_calls = [ AssistantPromptMessage.ToolCall( @@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): ) for tool_call in tool_calls ] - else: - assistant_message.content = response self._current_thoughts.append(assistant_message) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 3ac83b4c96..9e46d72893 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage): :return: True if prompt message is empty, False otherwise """ - if not super().is_empty() and not self.tool_calls: - return False - - return True + return super().is_empty() and not self.tool_calls class SystemPromptMessage(PromptMessage): From 2d53ba86710d63cfc368684fde8fcca0e7a77f76 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 13 Jan 2026 15:21:06 +0800 Subject: [PATCH 11/29] fix: fix object value is optional should skip validate (#30894) --- api/core/app/apps/base_app_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index e4486e892c..07bae66867 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -189,7 +189,7 @@ class BaseAppGenerator: elif value == 0: value = False case VariableEntityType.JSON_OBJECT: - if not isinstance(value, dict): + if value and not isinstance(value, dict): raise ValueError(f"{variable_entity.variable} in input form must be a dict") case _: raise AssertionError("this statement should be unreachable.") From c09e29c3f83c7400e2b4d98730d388321681708b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 13 Jan 2026 15:26:41 +0800 Subject: [PATCH 12/29] chore: rename the migration file (#30893) --- ...01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename api/migrations/versions/{2026_01_09_1630-905527cc8fd3_.py => 2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py} (100%) diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py similarity index 100% rename from api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py rename to api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py From ea708e7a32542614bb852a05146e3c170ce29125 Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Tue, 13 Jan 2026 15:40:43 +0800 Subject: [PATCH 13/29] fix(web): add null check for SSE stream bufferObj to prevent TypeError (#30131) Signed-off-by: majiayu000 <1835304752@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 --- web/service/base.spec.ts | 231 +++++++++++++++++++++++++++++++++++++++ web/service/base.ts | 11 ++ 2 files changed, 242 insertions(+) create mode 100644 web/service/base.spec.ts diff --git a/web/service/base.spec.ts b/web/service/base.spec.ts new file mode 100644 index 0000000000..d6ed242ed9 --- /dev/null +++ b/web/service/base.spec.ts @@ -0,0 +1,231 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { handleStream } from './base' + +describe('handleStream', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Invalid response data handling', () => { + it('should handle null bufferObj from JSON.parse gracefully', async () => { + // Arrange + const onData = vi.fn() + const onCompleted = vi.fn() + + // Create a mock response that returns 'data: null' + const mockReader = { + read: vi.fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode('data: null\n'), + }) + .mockResolvedValueOnce({ + done: true, + value: undefined, + }), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } as unknown as Response + + // Act + handleStream(mockResponse, onData, onCompleted) + + // Wait for the stream to be processed + await new Promise(resolve => setTimeout(resolve, 50)) + + // Assert + expect(onData).toHaveBeenCalledWith('', true, { + conversationId: undefined, + messageId: '', + errorMessage: 'Invalid response data', + errorCode: 'invalid_data', + }) + expect(onCompleted).toHaveBeenCalledWith(true, 'Invalid response data') + }) + + it('should handle non-object bufferObj from JSON.parse gracefully', async () => { + // Arrange + const onData = vi.fn() + const onCompleted = vi.fn() + + // Create a mock response that returns a primitive value + const mockReader = { + read: vi.fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode('data: "string"\n'), + }) + .mockResolvedValueOnce({ + done: true, + value: undefined, + }), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } as unknown as Response + + // Act + handleStream(mockResponse, onData, onCompleted) + + // Wait for the stream to be processed + await new Promise(resolve => setTimeout(resolve, 50)) + + // Assert + expect(onData).toHaveBeenCalledWith('', true, { + conversationId: undefined, + messageId: '', + errorMessage: 'Invalid response data', + errorCode: 'invalid_data', + }) + expect(onCompleted).toHaveBeenCalledWith(true, 'Invalid response data') + }) + + it('should handle valid message event correctly', async () => { + // Arrange + const onData = vi.fn() + const onCompleted = vi.fn() + + const validMessage = { + event: 'message', + answer: 'Hello world', + conversation_id: 'conv-123', + task_id: 'task-456', + id: 'msg-789', + } + + const mockReader = { + read: vi.fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode(`data: ${JSON.stringify(validMessage)}\n`), + }) + .mockResolvedValueOnce({ + done: true, + value: undefined, + }), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } as unknown as Response + + // Act + handleStream(mockResponse, onData, onCompleted) + + // Wait for the stream to be processed + await new Promise(resolve => setTimeout(resolve, 50)) + + // Assert + expect(onData).toHaveBeenCalledWith('Hello world', true, { + conversationId: 'conv-123', + taskId: 'task-456', + messageId: 'msg-789', + }) + expect(onCompleted).toHaveBeenCalled() + }) + + it('should handle error status 400 correctly', async () => { + // Arrange + const onData = vi.fn() + const onCompleted = vi.fn() + + const errorMessage = { + status: 400, + message: 'Bad request', + code: 'bad_request', + } + + const mockReader = { + read: vi.fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode(`data: ${JSON.stringify(errorMessage)}\n`), + }) + .mockResolvedValueOnce({ + done: true, + value: undefined, + }), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } as unknown as Response + + // Act + handleStream(mockResponse, onData, onCompleted) + + // Wait for the stream to be processed + await new Promise(resolve => setTimeout(resolve, 50)) + + // Assert + expect(onData).toHaveBeenCalledWith('', false, { + conversationId: undefined, + messageId: '', + errorMessage: 'Bad request', + errorCode: 'bad_request', + }) + expect(onCompleted).toHaveBeenCalledWith(true, 'Bad request') + }) + + it('should handle malformed JSON gracefully', async () => { + // Arrange + const onData = vi.fn() + const onCompleted = vi.fn() + + const mockReader = { + read: vi.fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode('data: {invalid json}\n'), + }) + .mockResolvedValueOnce({ + done: true, + value: undefined, + }), + } + + const mockResponse = { + ok: true, + body: { + getReader: () => mockReader, + }, + } as unknown as Response + + // Act + handleStream(mockResponse, onData, onCompleted) + + // Wait for the stream to be processed + await new Promise(resolve => setTimeout(resolve, 50)) + + // Assert - malformed JSON triggers the catch block which calls onData and returns + expect(onData).toHaveBeenCalled() + expect(onCompleted).toHaveBeenCalled() + }) + + it('should throw error when response is not ok', () => { + // Arrange + const onData = vi.fn() + const mockResponse = { + ok: false, + } as unknown as Response + + // Act & Assert + expect(() => handleStream(mockResponse, onData)).toThrow('Network response was not ok') + }) + }) +}) diff --git a/web/service/base.ts b/web/service/base.ts index d9f3dba53a..2ab115f96c 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -217,6 +217,17 @@ export const handleStream = ( }) return } + if (!bufferObj || typeof bufferObj !== 'object') { + onData('', isFirstMessage, { + conversationId: undefined, + messageId: '', + errorMessage: 'Invalid response data', + errorCode: 'invalid_data', + }) + hasError = true + onCompleted?.(true, 'Invalid response data') + return + } if (bufferObj.status === 400 || !bufferObj.event) { onData('', false, { conversationId: undefined, From 0e33dfb5c20c450b6e54ca186995d414bd613d1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E6=9E=95?= <13218909716@139.com> Date: Tue, 13 Jan 2026 15:42:32 +0800 Subject: [PATCH 14/29] =?UTF-8?q?fix:=20In=20the=20LLM=20model=20in=20dify?= =?UTF-8?q?,=20when=20a=20message=20is=20added,=20the=20first=20cli?= =?UTF-8?q?=E2=80=A6=20(#29540)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 青枕 --- .../workflow/nodes/llm/components/config-prompt.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt.tsx index 0cddb15ab6..5b28c9b48f 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt.tsx @@ -106,12 +106,12 @@ const ConfigPrompt: FC = ({ const handleAddPrompt = useCallback(() => { const newPrompt = produce(payload as PromptItem[], (draft) => { if (draft.length === 0) { - draft.push({ role: PromptRole.system, text: '' }) + draft.push({ role: PromptRole.system, text: '', id: uuid4() }) return } const isLastItemUser = draft[draft.length - 1].role === PromptRole.user - draft.push({ role: isLastItemUser ? PromptRole.assistant : PromptRole.user, text: '' }) + draft.push({ role: isLastItemUser ? PromptRole.assistant : PromptRole.user, text: '', id: uuid4() }) }) onChange(newPrompt) }, [onChange, payload]) From 491e1fd6a4a1dba36459d92411120e5e2defb2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 13 Jan 2026 15:42:44 +0800 Subject: [PATCH 15/29] chore: case insensitive email (#29978) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- --- api/commands.py | 18 +- api/controllers/console/auth/activate.py | 8 +- .../console/auth/email_register.py | 33 +-- .../console/auth/forgot_password.py | 29 +- api/controllers/console/auth/login.py | 70 +++-- api/controllers/console/auth/oauth.py | 17 +- api/controllers/console/setup.py | 3 +- api/controllers/console/workspace/account.py | 49 ++-- api/controllers/console/workspace/members.py | 15 +- api/controllers/web/forgot_password.py | 31 ++- api/controllers/web/login.py | 12 +- api/services/account_service.py | 41 ++- api/services/webapp_auth_service.py | 5 +- .../console/auth/test_account_activation.py | 66 ++++- .../auth/test_authentication_security.py | 6 +- .../console/auth/test_email_register.py | 177 +++++++++++++ .../console/auth/test_forgot_password.py | 176 +++++++++++++ .../console/auth/test_login_logout.py | 62 ++++- .../controllers/console/auth/test_oauth.py | 86 +++++- .../console/auth/test_password_reset.py | 104 +++++--- .../controllers/console/test_setup.py | 39 +++ .../console/test_workspace_account.py | 247 ++++++++++++++++++ .../console/test_workspace_members.py | 82 ++++++ .../controllers/web/test_forgot_password.py | 195 -------------- .../web/test_web_forgot_password.py | 226 ++++++++++++++++ .../controllers/web/test_web_login.py | 91 +++++++ .../services/test_account_service.py | 99 ++++++- 27 files changed, 1611 insertions(+), 376 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/auth/test_email_register.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_forgot_password.py create mode 100644 api/tests/unit_tests/controllers/console/test_setup.py create mode 100644 api/tests/unit_tests/controllers/console/test_workspace_account.py create mode 100644 api/tests/unit_tests/controllers/console/test_workspace_members.py delete mode 100644 api/tests/unit_tests/controllers/web/test_forgot_password.py create mode 100644 api/tests/unit_tests/controllers/web/test_web_forgot_password.py create mode 100644 api/tests/unit_tests/controllers/web/test_web_login.py diff --git a/api/commands.py b/api/commands.py index e24b1826ee..20ce22a6c7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile +from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.provider import Provider, ProviderModel from models.provider_ids import DatasourceProviderID, ToolProviderID @@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): click.echo(click.style("Passwords do not match.", fg="red")) return + normalized_email = email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm): base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): click.echo(click.style("New emails do not match.", fg="red")) return + normalized_new_email = new_email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: - email_validate(new_email) + email_validate(normalized_new_email) except: click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account.email = new_email + account.email = normalized_new_email click.echo(click.style("Email updated successfully.", fg="green")) @@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No return # Create account - email = email.strip() + email = email.strip().lower() if "@" not in email: click.echo(click.style("Invalid email address.", fg="red")) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index fe70d930fb..cfc673880c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -63,10 +63,9 @@ class ActivateCheckApi(Resource): args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workspaceId = args.workspace_id - reg_email = args.email token = args.token - invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) + invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -100,11 +99,12 @@ class ActivateApi(Resource): def post(self): args = ActivatePayload.model_validate(console_ns.payload) - invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token) + normalized_request_email = args.email.lower() if args.email else None + invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args.workspace_id, args.email, args.token) + RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token) account = invitation["account"] account.name = args.name diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index fa082c735d..c2a95ddad2 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource): @email_register_enabled def post(self): args = EmailRegisterSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource): if args.language in languages: language = args.language - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() - token = None - token = AccountService.send_email_register_email(email=args.email, account=account, language=language) + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource): def post(self): args = EmailRegisterValidityPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email) if is_email_register_error_rate_limit: raise EmailRegisterLimitError() @@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_email_register_error_rate_limit(args.email) + AccountService.add_email_register_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource): user_email, code=args.code, additional_data={"phase": "register"} ) - AccountService.reset_email_register_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_email_register_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/email-register") @@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource): AccountService.revoke_email_register_token(args.token) email = register_data.get("email", "") + normalized_email = email.lower() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: raise EmailAlreadyInUseError() else: - account = self._create_new_account(email, args.password_confirm) + account = self._create_new_account(normalized_email, args.password_confirm) if not account: raise AccountNotFoundError() token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} - def _create_new_account(self, email, password) -> Account | None: + def _create_new_account(self, email: str, password: str) -> Account | None: # Create new account if allowed account = None try: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 661f591182..394f205d93 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.console import console_ns @@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account from services.account_service import AccountService, TenantService from services.feature_service import FeatureService @@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource): @email_password_login_enabled def post(self): args = ForgotPasswordSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( account=account, - email=args.email, + email=normalized_email, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, ) @@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource): # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args.code, additional_data={"phase": "reset"} + token_email, code=args.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/forgot-password/resets") @@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 4a52bf8abe..400df138b8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -90,32 +90,38 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" args = LoginPayload.model_validate(console_ns.payload) + request_email = args.email + normalized_email = request_email.lower() - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email) + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() + invite_token = args.invite_token invitation_data: dict[str, Any] | None = None - if args.invite_token: - invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token) + if invite_token: + invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) + if invitation_data is None: + invite_token = None try: if invitation_data: data = invitation_data.get("data", {}) invitee_email = data.get("email") if data else None - if invitee_email != args.email: + invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email + if invitee_email_normalized != normalized_email: raise InvalidEmailError() - account = AccountService.authenticate(args.email, args.password, args.invite_token) - else: - account = AccountService.authenticate(args.email, args.password) + account = _authenticate_account_with_case_fallback( + request_email, normalized_email, args.password, invite_token + ) except services.errors.account.AccountLoginError: raise AccountBannedError() - except services.errors.account.AccountPasswordError: - AccountService.add_login_error_rate_limit(args.email) - raise AuthenticationFailedError() + except services.errors.account.AccountPasswordError as exc: + AccountService.add_login_error_rate_limit(normalized_email) + raise AuthenticationFailedError() from exc # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -130,7 +136,7 @@ class LoginApi(Resource): } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args.email) + AccountService.reset_login_error_rate_limit(normalized_email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource): @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): args = EmailPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args.email) + account = _get_account_with_case_fallback(args.email) except AccountRegisterError: raise AccountInFreezeError() token = AccountService.send_reset_password_email( - email=args.email, + email=normalized_email, account=account, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, @@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource): @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): args = EmailPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource): else: language = "en-US" try: - account = AccountService.get_user_through_email(args.email) + account = _get_account_with_case_fallback(args.email) except AccountRegisterError: raise AccountInFreezeError() if account is None: if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_email_code_login_email(email=args.email, language=language) + token = AccountService.send_email_code_login_email(email=normalized_email, language=language) else: raise AccountNotFound() else: @@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource): def post(self): args = EmailCodeLoginPayload.model_validate(console_ns.payload) - user_email = args.email + original_email = args.email + user_email = original_email.lower() language = args.language token_data = AccountService.get_email_code_login_data(args.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args.email: + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != user_email: raise InvalidEmailError() if token_data["code"] != args.code: @@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource): AccountService.revoke_email_code_login_token(args.token) try: - account = AccountService.get_user_through_email(user_email) + account = _get_account_with_case_fallback(original_email) except AccountRegisterError: raise AccountInFreezeError() if account: @@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource): except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args.email) + AccountService.reset_login_error_rate_limit(user_email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -309,3 +320,22 @@ class RefreshTokenApi(Resource): return response except Exception as e: return {"result": "fail", "message": str(e)}, 401 + + +def _get_account_with_case_fallback(email: str): + account = AccountService.get_user_through_email(email) + if account or email == email.lower(): + return account + + return AccountService.get_user_through_email(email.lower()) + + +def _authenticate_account_with_case_fallback( + original_email: str, normalized_email: str, password: str, invite_token: str | None +): + try: + return AccountService.authenticate(original_email, password, invite_token) + except services.errors.account.AccountPasswordError: + if original_email == normalized_email: + raise + return AccountService.authenticate(normalized_email, password, invite_token) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index c20e83d36f..112e152432 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,7 +3,6 @@ import logging import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -118,7 +117,10 @@ class OAuthCallback(Resource): invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) - if invitation_email != user_info.email: + invitation_email_normalized = ( + invitation_email.lower() if isinstance(invitation_email, str) else invitation_email + ) + if invitation_email_normalized != user_info.email.lower(): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") @@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> if not account: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account @@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, tenant_was_created.send(new_tenant) if not account: + normalized_email = user_info.email.lower() oauth_new_user = True if not FeatureService.get_system_features().is_allow_register: - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountRegisterError( description=( "This email account has been deleted within the past " @@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" account = RegisterService.register( - email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider + email=normalized_email, + name=account_name, + password=None, + open_id=user_info.id, + provider=provider, ) # Set interface language diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 7fa02ae280..ed22ef045d 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -84,10 +84,11 @@ class SetupApi(Resource): raise NotInitValidateError() args = SetupRequestPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() # setup RegisterService.setup( - email=args.email, + email=normalized_email, name=args.name, password=args.password, ip_address=extract_remote_ip(request), diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 03ad0f423b..527aabbc3d 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -41,7 +41,7 @@ from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required -from models import Account, AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource): else: language = "en-US" account = None - user_email = args.email + user_email = None + email_for_sending = args.email.lower() if args.phase is not None and args.phase == "new_email": if args.token is None: raise InvalidTokenError() @@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") - if user_email != current_user.email: + if user_email.lower() != current_user.email.lower(): raise InvalidEmailError() + + user_email = current_user.email else: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) if account is None: raise AccountNotFound() + email_for_sending = account.email + user_email = account.email token = AccountService.send_change_email_email( - account=account, email=args.email, old_email=user_email, language=language, phase=args.phase + account=account, + email=email_for_sending, + old_email=user_email, + language=language, + phase=args.phase, ) return {"result": "success", "data": token} @@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource): payload = console_ns.payload or {} args = ChangeEmailValidityPayload.model_validate(payload) - user_email = args.email + user_email = args.email.lower() - is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email) if is_change_email_error_rate_limit: raise EmailChangeLimitError() @@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_change_email_error_rate_limit(args.email) + AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource): user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} ) - AccountService.reset_change_email_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_change_email_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/account/change-email/reset") @@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource): def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) + normalized_new_email = args.new_email.lower() - if AccountService.is_account_in_freeze(args.new_email): + if AccountService.is_account_in_freeze(normalized_new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.new_email): + if not AccountService.check_email_unique(normalized_new_email): raise EmailAlreadyInUseError() reset_data = AccountService.get_change_email_data(args.token) @@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource): old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() - if current_user.email != old_email: + if current_user.email.lower() != old_email.lower(): raise AccountNotFound() - updated_account = AccountService.update_account_email(current_user, email=args.new_email) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( - email=args.new_email, + email=normalized_new_email, ) return updated_account @@ -645,8 +657,9 @@ class CheckEmailUnique(Resource): def post(self): payload = console_ns.payload or {} args = CheckEmailUniquePayload.model_validate(payload) - if AccountService.is_account_in_freeze(args.email): + normalized_email = args.email.lower() + if AccountService.is_account_in_freeze(normalized_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.email): + if not AccountService.check_email_unique(normalized_email): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 0142e14fb0..e9bd2b8f94 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource): raise WorkspaceMembersLimitExceeded() for invitee_email in invitee_emails: + normalized_invitee_email = invitee_email.lower() try: if not inviter.current_tenant: raise ValueError("No current tenant") token = RegisterService.invite_new_member( - inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + tenant=inviter.current_tenant, + email=invitee_email, + language=interface_language, + role=invitee_role, + inviter=inviter, ) - encoded_invitee_email = parse.quote(invitee_email) + encoded_invitee_email = parse.quote(normalized_invitee_email) invitation_results.append( { "status": "success", - "email": invitee_email, + "email": normalized_invitee_email, "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}", } ) except AccountAlreadyInTenantError: invitation_results.append( - {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"} ) except Exception as e: - invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) + invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)}) return { "result": "success", diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 690b76655f..91d206f727 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models @@ -22,7 +21,7 @@ from controllers.web import web_ns from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account +from models.account import Account from services.account_service import AccountService @@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource): def post(self): payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) + request_email = payload.email + normalized_email = request_email.lower() + ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() @@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) + token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language) return {"result": "success", "data": token} @@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) - user_email = payload.email + user_email = payload.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + + if user_email != normalized_token_email: raise InvalidEmailError() if payload.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(payload.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource): # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=payload.code, additional_data={"phase": "reset"} + token_email, code=payload.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(payload.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @web_ns.route("/forgot-password/resets") @@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 5847f4ae3a..e8053acdfd 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource): ) args = parser.parse_args() - user_email = args["email"] + user_email = args["email"].lower() token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + if normalized_token_email != user_email: raise InvalidEmailError() if token_data["code"] != args["code"]: raise EmailCodeError() WebAppAuthService.revoke_email_code_login_token(args["token"]) - account = WebAppAuthService.get_user_through_email(user_email) + account = WebAppAuthService.get_user_through_email(token_email) if not account: raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(user_email) response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response diff --git a/api/services/account_service.py b/api/services/account_service.py index d38c9d5a66..709ef749bc 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, cast from pydantic import BaseModel -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -748,6 +748,21 @@ class AccountService: cls.email_code_login_rate_limiter.increment_rate_limit(email) return token + @staticmethod + def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + """ + Retrieve an account by email and fall back to the lowercase email if the original lookup fails. + + This keeps backward compatibility for older records that stored uppercase emails while the + rest of the system gradually normalizes new inputs. + """ + query_session = session or db.session + account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @@ -1363,16 +1378,22 @@ class RegisterService: if not inviter: raise ValueError("Inviter is required") + normalized_email = email.lower() + """Invite new member""" with Session(db.engine) as session: - account = session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") - name = email.split("@")[0] + name = normalized_email.split("@")[0] account = cls.register( - email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True + email=normalized_email, + name=name, + language=language, + status=AccountStatus.PENDING, + is_setup=True, ) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) @@ -1394,7 +1415,7 @@ class RegisterService: # send email send_invite_member_mail_task.delay( language=language, - to=email, + to=account.email, token=token, inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, @@ -1493,6 +1514,16 @@ class RegisterService: invitation: dict = json.loads(data) return invitation + @classmethod + def get_invitation_with_case_fallback( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + if invitation or not email or email == email.lower(): + return invitation + normalized_email = email.lower() + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + def _generate_refresh_token(length: int = 64): token = secrets.token_hex(length) diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 9bd797a45f..5ca0b63001 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -12,6 +12,7 @@ from libs.passport import PassportService from libs.password import compare_password from models import Account, AccountStatus from models.model import App, EndUser, Site +from services.account_service import AccountService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError @@ -32,7 +33,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = db.session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: raise AccountNotFoundError() @@ -52,7 +53,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = db.session.query(Account).where(Account.email == email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: return None diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index da21e0e358..d3e864a75a 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -40,7 +40,7 @@ class TestActivateCheckApi: "tenant": tenant, } - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): """ Test checking valid invitation token. @@ -66,7 +66,7 @@ class TestActivateCheckApi: assert response["data"]["workspace_id"] == "workspace-123" assert response["data"]["email"] == "invitee@example.com" - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_invalid_invitation_token(self, mock_get_invitation, app): """ Test checking invalid invitation token. @@ -88,7 +88,7 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is False - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): """ Test checking token without workspace ID. @@ -109,7 +109,7 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): """ Test checking token without email parameter. @@ -130,6 +130,20 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") + def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation): + """Ensure token validation uses lowercase emails.""" + mock_get_invitation.return_value = mock_invitation + + with app.test_request_context( + "/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token" + ): + api = ActivateCheckApi() + response = api.get() + + assert response["is_valid"] is True + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + class TestActivateApi: """Test cases for account activation endpoint.""" @@ -212,7 +226,7 @@ class TestActivateApi: mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") mock_db.session.commit.assert_called_once() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_activation_with_invalid_token(self, mock_get_invitation, app): """ Test account activation with invalid token. @@ -241,7 +255,7 @@ class TestActivateApi: with pytest.raises(AlreadyActivateError): api.post() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_sets_interface_theme( @@ -290,7 +304,7 @@ class TestActivateApi: ("es-ES", "Europe/Madrid"), ], ) - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_with_different_locales( @@ -336,7 +350,7 @@ class TestActivateApi: assert mock_account.interface_language == language assert mock_account.timezone == timezone - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_returns_success_response( @@ -376,7 +390,7 @@ class TestActivateApi: # Assert assert response == {"result": "success"} - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_without_workspace_id( @@ -415,3 +429,37 @@ class TestActivateApi: # Assert assert response["result"] == "success" mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token") + + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") + @patch("controllers.console.auth.activate.RegisterService.revoke_token") + @patch("controllers.console.auth.activate.db") + def test_activation_normalizes_email_before_lookup( + self, + mock_db, + mock_revoke_token, + mock_get_invitation, + app, + mock_invitation, + mock_account, + ): + """Ensure uppercase emails are normalized before lookup and revocation.""" + mock_get_invitation.return_value = mock_invitation + + with app.test_request_context( + "/activate", + method="POST", + json={ + "workspace_id": "workspace-123", + "email": "Invitee@Example.com", + "token": "valid_token", + "name": "John Doe", + "interface_language": "en-US", + "timezone": "UTC", + }, + ): + api = ActivateApi() + response = api.post() + + assert response["result"] == "success" + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index eb21920117..cb4fe40944 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -34,7 +34,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_allowed( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): @@ -67,7 +67,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_wrong_password_returns_error( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db ): @@ -100,7 +100,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_disabled( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/unit_tests/controllers/console/auth/test_email_register.py new file mode 100644 index 0000000000..724c80f18c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register.py @@ -0,0 +1,177 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.email_register import ( + EmailRegisterCheckApi, + EmailRegisterResetApi, + EmailRegisterSendEmailApi, +) +from services.account_service import AccountService + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestEmailRegisterSendEmailApi: + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") + @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") + @patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_send_email_normalizes_and_falls_back( + self, + mock_extract_ip, + mock_is_email_send_ip_limit, + mock_is_freeze, + mock_send_mail, + mock_get_account, + mock_session_cls, + app, + ): + mock_send_mail.return_value = "token-123" + mock_is_freeze.return_value = False + mock_account = MagicMock() + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = mock_account + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register/send-email", + method="POST", + json={"email": "Invitee@Example.com", "language": "en-US"}, + ): + response = EmailRegisterSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_is_freeze.assert_called_once_with("invitee@example.com") + mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) + mock_extract_ip.assert_called_once() + mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") + + +class TestEmailRegisterCheckApi: + @patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.generate_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit") + def test_validity_normalizes_email_before_checks( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register/validity", + method="POST", + json={"email": "User@Example.com", "code": "4321", "token": "token-123"}, + ): + response = EmailRegisterCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("user@example.com") + mock_generate_token.assert_called_once_with( + "user@example.com", code="4321", additional_data={"phase": "register"} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke.assert_called_once_with("token-123") + + +class TestEmailRegisterResetApi: + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_creates_account_with_normalized_email( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_session_cls, + mock_create_account, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"}, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!") + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + + assert account is expected_account + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py new file mode 100644 index 0000000000..8403777dc9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) +from services.account_service import AccountService + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1") + def test_send_normalizes_email( + self, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_account, + mock_session_cls, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_email.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + controller_features = SimpleNamespace(is_allow_register=True) + with ( + patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), + patch( + "controllers.console.auth.forgot_password.FeatureService.get_system_features", + return_value=controller_features, + ), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_email.assert_called_once_with( + account=mock_account, + email="user@example.com", + language="zh-Hans", + is_allow_register=True, + ) + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_extract_ip.assert_called_once() + + +class TestForgotPasswordCheckApi: + @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_check_normalizes_email( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password/validity", + method="POST", + json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("admin@example.com") + mock_generate_token.assert_called_once_with( + "Admin@Example.com", + code="4321", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("admin@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + + +class TestForgotPasswordResetApi: + @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + def test_reset_fetches_account_with_original_email( + self, + mock_get_reset_data, + mock_revoke_token, + mock_get_account, + mock_session_cls, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with ( + patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + + assert account is expected_account + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 3a2cf7bad7..560971206f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -76,7 +76,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -120,7 +120,7 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!") + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None) mock_login.assert_called_once() mock_reset_rate_limit.assert_called_once_with("test@example.com") assert response.json["result"] == "success" @@ -128,7 +128,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -182,7 +182,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login rejection when rate limit is exceeded. @@ -230,7 +230,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") def test_login_fails_with_invalid_credentials( @@ -269,7 +269,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app @@ -298,7 +298,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.FeatureService.get_system_features") @@ -343,7 +343,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login failure when invitation email doesn't match login email. @@ -371,6 +371,52 @@ class TestLoginApi: with pytest.raises(InvalidEmailError): login_api.post() + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) + @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") + @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") + @patch("controllers.console.auth.login.TenantService.get_join_tenants") + @patch("controllers.console.auth.login.AccountService.login") + @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") + def test_login_retries_with_lowercase_email( + self, + mock_reset_rate_limit, + mock_login_service, + mock_get_tenants, + mock_add_rate_limit, + mock_authenticate, + mock_get_invitation, + mock_is_rate_limit, + mock_db, + app, + mock_account, + mock_token_pair, + ): + """Test that login retries with lowercase email when uppercase lookup fails.""" + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_is_rate_limit.return_value = False + mock_get_invitation.return_value = None + mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] + mock_get_tenants.return_value = [MagicMock()] + mock_login_service.return_value = mock_token_pair + + with app.test_request_context( + "/login", + method="POST", + json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")}, + ): + response = LoginApi().post() + + assert response.json["result"] == "success" + assert mock_authenticate.call_args_list == [ + (("Upper@Example.com", "ValidPass123!", None), {}), + (("upper@example.com", "ValidPass123!", None), {}), + ] + mock_add_rate_limit.assert_not_called() + mock_reset_rate_limit.assert_called_once_with("upper@example.com") + class TestLogoutApi: """Test cases for the LogoutApi endpoint.""" diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 3ddfcdb832..6345c2ab23 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -12,6 +12,7 @@ from controllers.console.auth.oauth import ( ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus +from services.account_service import AccountService from services.errors.account import AccountRegisterError @@ -215,6 +216,34 @@ class TestOAuthCallback: assert status_code == 400 assert response["error"] == expected_error + @patch("controllers.console.auth.oauth.dify_config") + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.redirect") + def test_invitation_comparison_is_case_insensitive( + self, + mock_redirect, + mock_register_service, + mock_get_providers, + mock_config, + resource, + app, + oauth_setup, + ): + mock_config.CONSOLE_WEB_URL = "http://localhost:3000" + oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo( + id="123", name="Test User", email="User@Example.com" + ) + mock_get_providers.return_value = {"github": oauth_setup["provider"]} + mock_register_service.is_valid_invite_token.return_value = True + mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"} + + with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"): + resource.get("github") + + mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123") + mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123") + @pytest.mark.parametrize( ("account_status", "expected_redirect"), [ @@ -395,12 +424,12 @@ class TestAccountGeneration: account.name = "Test User" return account - @patch("controllers.console.auth.oauth.db") - @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.Session") - @patch("controllers.console.auth.oauth.select") + @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account + self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account ): # Mock db.engine for Session creation mock_db.engine = MagicMock() @@ -410,15 +439,31 @@ class TestAccountGeneration: result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account mock_account_model.get_by_openid.assert_called_once_with("github", "123") + mock_get_account.assert_not_called() - # Test fallback to email + # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account + mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + + def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + mock_session = MagicMock() + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] + + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + + assert result == expected_account + assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( ("allow_register", "existing_account", "should_create"), @@ -466,6 +511,35 @@ class TestAccountGeneration: mock_register_service.register.assert_called_once_with( email="test@example.com", name="Test User", password=None, open_id="123", provider="github" ) + else: + mock_register_service.register.assert_not_called() + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + @patch("controllers.console.auth.oauth.db") + def test_should_register_with_lowercase_email( + self, + mock_db, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app, + ): + user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "en-US"}): + _generate_account("github", user_info) + + mock_register_service.register.assert_called_once_with( + email="upper@example.com", name="Test User", password=None, open_id="123", provider="github" + ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @patch("controllers.console.auth.oauth.TenantService") diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py index f584952a00..9488cf528e 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py @@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError +@pytest.fixture(autouse=True) +def _mock_forgot_password_session(): + with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.__exit__.return_value = None + yield mock_session + + +@pytest.fixture(autouse=True) +def _mock_forgot_password_db(): + with patch("controllers.console.auth.forgot_password.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db + + class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi: return account @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") def test_send_reset_email_success( self, mock_get_features, mock_send_email, - mock_select, - mock_session, + mock_get_account, mock_is_ip_limit, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_is_ip_limit.return_value = False - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" mock_get_features.return_value.is_allow_register = True @@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi: ], ) @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") def test_send_reset_email_language_handling( self, mock_get_features, mock_send_email, - mock_select, - mock_session, + mock_get_account, mock_is_ip_limit, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_is_ip_limit.return_value = False - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_send_email.return_value = "token" mock_get_features.return_value.is_allow_register = True @@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi: assert response["email"] == "test@example.com" assert response["token"] == "new_token" mock_revoke_token.assert_called_once_with("old_token") + mock_generate_token.assert_called_once_with( + "test@example.com", code="123456", additional_data={"phase": "reset"} + ) mock_reset_rate_limit.assert_called_once_with("test@example.com") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + def test_verify_code_preserves_token_email_case( + self, + mock_reset_rate_limit, + mock_generate_token, + mock_revoke_token, + mock_get_data, + mock_is_rate_limit, + mock_db, + app, + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} + mock_generate_token.return_value = (None, "fresh-token") + + with app.test_request_context( + "/forgot-password/validity", + method="POST", + json={"email": "user@example.com", "code": "999888", "token": "upper_token"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"} + mock_generate_token.assert_called_once_with( + "User@Example.com", code="999888", additional_data={"phase": "reset"} + ) + mock_revoke_token.assert_called_once_with("upper_token") + mock_reset_rate_limit.assert_called_once_with("user@example.com") + @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): @@ -355,20 +395,16 @@ class TestForgotPasswordResetApi: return account @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, - mock_select, - mock_session, + mock_get_account, mock_revoke_token, mock_get_data, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -383,11 +419,8 @@ class TestForgotPasswordResetApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act @@ -475,13 +508,11 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") def test_reset_password_account_not_found( - self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app + self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app ): """ Test password reset for non-existent account. @@ -491,11 +522,8 @@ class TestForgotPasswordResetApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = None # Act & Assert with app.test_request_context( diff --git a/api/tests/unit_tests/controllers/console/test_setup.py b/api/tests/unit_tests/controllers/console/test_setup.py new file mode 100644 index 0000000000..e7882dcd2b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_setup.py @@ -0,0 +1,39 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from controllers.console.setup import SetupApi + + +class TestSetupApi: + def test_post_lowercases_email_before_register(self): + """Ensure setup registration normalizes email casing.""" + payload = { + "email": "Admin@Example.com", + "name": "Admin User", + "password": "ValidPass123!", + "language": "en-US", + } + setup_api = SetupApi(api=None) + + mock_console_ns = SimpleNamespace(payload=payload) + + with ( + patch("controllers.console.setup.console_ns", mock_console_ns), + patch("controllers.console.setup.get_setup_status", return_value=False), + patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0), + patch("controllers.console.setup.get_init_validate_status", return_value=True), + patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"), + patch("controllers.console.setup.request", object()), + patch("controllers.console.setup.RegisterService.setup") as mock_register, + ): + response, status = setup_api.post() + + assert response == {"result": "success"} + assert status == 201 + mock_register.assert_called_once_with( + email="admin@example.com", + name=payload["name"], + password=payload["password"], + ip_address="127.0.0.1", + language=payload["language"], + ) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py new file mode 100644 index 0000000000..9afc1c4166 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -0,0 +1,247 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.account import ( + AccountDeleteUpdateFeedbackApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + ChangeEmailSendEmailApi, + CheckEmailUnique, +) +from models import Account +from services.account_service import AccountService + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["RESTX_MASK_HEADER"] = "X-Fields" + app.login_manager = SimpleNamespace(_load_user=lambda: None) + return app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: + tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") + account = Account(name=account_id, email=email) + account.email = email + account.id = account_id + account.status = "active" + account._current_tenant = tenant_obj + return account + + +def _set_logged_in_user(account: Account): + g._login_user = account + g._current_tenant = account.current_tenant + + +class TestChangeEmailSend: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_phase( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = {"email": "current@example.com"} + mock_send_email.return_value = "token-abc" + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailSendEmailApi().post() + + assert response == {"result": "success", "data": "token-abc"} + mock_send_email.assert_called_once_with( + account=None, + email="new@example.com", + old_email="current@example.com", + language="en-US", + phase="new_email", + ) + mock_extract_ip.assert_called_once() + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_csrf.assert_called_once() + + +class TestChangeEmailValidity: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_validate_with_normalized_email( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("user@example.com", "acc2") + mock_current_account.return_value = (mock_account, None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "user@example.com", code="1234", old_email="old@example.com", additional_data={} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_csrf.assert_called_once() + + +class TestChangeEmailReset: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_before_update( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_account_after_update = _build_account("new@example.com", "acc3-updated") + mock_update_account.return_value = mock_account_after_update + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "New@Example.com", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + ChangeEmailResetApi().post() + + mock_is_freeze.assert_called_once_with("new@example.com") + mock_check_unique.assert_called_once_with("new@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_update_account.assert_called_once_with(current_user, email="new@example.com") + mock_send_notify.assert_called_once_with(email="new@example.com") + mock_csrf.assert_called_once() + + +class TestAccountDeletionFeedback: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") + def test_should_normalize_feedback_email(self, mock_update, mock_db, app): + _mock_wraps_db(mock_db) + with app.test_request_context( + "/account/delete/feedback", + method="POST", + json={"email": "User@Example.com", "feedback": "test"}, + ): + response = AccountDeleteUpdateFeedbackApi().post() + + assert response == {"result": "success"} + mock_update.assert_called_once_with("User@Example.com", "test") + + +class TestCheckEmailUnique: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): + _mock_wraps_db(mock_db) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + + with app.test_request_context( + "/account/change-email/check-email-unique", + method="POST", + json={"email": "Case@Test.com"}, + ): + response = CheckEmailUnique().post() + + assert response == {"result": "success"} + mock_is_freeze.assert_called_once_with("case@test.com") + mock_check_unique.assert_called_once_with("case@test.com") + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + session = MagicMock() + first = MagicMock() + first.scalar_one_or_none.return_value = None + second = MagicMock() + expected_account = MagicMock() + second.scalar_one_or_none.return_value = expected_account + session.execute.side_effect = [first, second] + + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + + assert result is expected_account + assert session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py new file mode 100644 index 0000000000..368892b922 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -0,0 +1,82 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.members import MemberInviteEmailApi +from models.account import Account, TenantAccountRole + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + return flask_app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_feature_flags(): + placeholder_quota = SimpleNamespace(limit=0, size=0) + workspace_members = SimpleNamespace(is_available=lambda count: True) + return SimpleNamespace( + billing=SimpleNamespace(enabled=False), + workspace_members=workspace_members, + members=placeholder_quota, + apps=placeholder_quota, + vector_space=placeholder_quota, + documents_upload_quota=placeholder_quota, + annotation_quota_limit=placeholder_quota, + ) + + +class TestMemberInviteEmailApi: + @patch("controllers.console.workspace.members.FeatureService.get_features") + @patch("controllers.console.workspace.members.RegisterService.invite_new_member") + @patch("controllers.console.workspace.members.current_account_with_tenant") + @patch("controllers.console.wraps.db") + @patch("libs.login.check_csrf_token", return_value=None) + def test_invite_normalizes_emails( + self, + mock_csrf, + mock_db, + mock_current_account, + mock_invite_member, + mock_get_features, + app, + ): + _mock_wraps_db(mock_db) + mock_get_features.return_value = _build_feature_flags() + mock_invite_member.return_value = "token-abc" + + tenant = SimpleNamespace(id="tenant-1", name="Test Tenant") + inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active") + mock_current_account.return_value = (inviter, tenant.id) + + with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"): + with app.test_request_context( + "/workspaces/current/members/invite-email", + method="POST", + json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"}, + ): + account = Account(name="tester", email="tester@example.com") + account._current_tenant = tenant + g._login_user = account + g._current_tenant = tenant + response, status_code = MemberInviteEmailApi().post() + + assert status_code == 201 + assert response["invitation_results"][0]["email"] == "user@example.com" + + assert mock_invite_member.call_count == 1 + call_args = mock_invite_member.call_args + assert call_args.kwargs["tenant"] == tenant + assert call_args.kwargs["email"] == "User@Example.com" + assert call_args.kwargs["language"] == "en-US" + assert call_args.kwargs["role"] == TenantAccountRole.EDITOR + assert call_args.kwargs["inviter"] == inviter + mock_csrf.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_forgot_password.py b/api/tests/unit_tests/controllers/web/test_forgot_password.py deleted file mode 100644 index d7c0d24f14..0000000000 --- a/api/tests/unit_tests/controllers/web/test_forgot_password.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Unit tests for controllers.web.forgot_password endpoints.""" - -from __future__ import annotations - -import base64 -import builtins -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView - -# Ensure flask_restx.api finds MethodView during import. -if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] - - -def _load_controller_module(): - """Import controllers.web.forgot_password using a stub package.""" - - import importlib - import importlib.util - import sys - from types import ModuleType - - parent_module_name = "controllers.web" - module_name = f"{parent_module_name}.forgot_password" - - if parent_module_name not in sys.modules: - from flask_restx import Namespace - - stub = ModuleType(parent_module_name) - stub.__file__ = "controllers/web/__init__.py" - stub.__path__ = ["controllers/web"] - stub.__package__ = "controllers" - stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True) - stub.web_ns = Namespace("web", description="Web API", path="/") - sys.modules[parent_module_name] = stub - - return importlib.import_module(module_name) - - -forgot_password_module = _load_controller_module() -ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi -ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi -ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi - - -@pytest.fixture -def app() -> Flask: - """Configure a minimal Flask app for request contexts.""" - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - -@pytest.fixture(autouse=True) -def _enable_web_endpoint_guards(): - """Stub enterprise and feature toggles used by route decorators.""" - - features = SimpleNamespace(enable_email_password_login=True) - with ( - patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True), - patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), - patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features), - ): - yield - - -@pytest.fixture(autouse=True) -def _mock_controller_db(): - """Replace controller-level db reference with a simple stub.""" - - fake_db = SimpleNamespace(engine=MagicMock(name="engine")) - fake_wraps_db = SimpleNamespace( - session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True)))) - ) - with ( - patch("controllers.web.forgot_password.db", fake_db), - patch("controllers.console.wraps.db", fake_wraps_db), - ): - yield fake_db - - -@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) -@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10") -def test_send_reset_email_success( - mock_extract_ip: MagicMock, - mock_is_ip_limit: MagicMock, - mock_session: MagicMock, - mock_send_email: MagicMock, - app: Flask, -): - """POST /forgot-password returns token when email exists and limits allow.""" - - mock_account = MagicMock() - session_ctx = MagicMock() - mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account - - with app.test_request_context( - "/forgot-password", - method="POST", - json={"email": "user@example.com"}, - ): - response = ForgotPasswordSendEmailApi().post() - - assert response == {"result": "success", "data": "reset-token"} - mock_extract_ip.assert_called_once() - mock_is_ip_limit.assert_called_once_with("203.0.113.10") - mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US") - - -@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") -@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token")) -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") -@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False) -def test_check_token_success( - mock_is_rate_limited: MagicMock, - mock_get_data: MagicMock, - mock_revoke: MagicMock, - mock_generate: MagicMock, - mock_reset_limit: MagicMock, - app: Flask, -): - """POST /forgot-password/validity validates the code and refreshes token.""" - - mock_get_data.return_value = {"email": "user@example.com", "code": "123456"} - - with app.test_request_context( - "/forgot-password/validity", - method="POST", - json={"email": "user@example.com", "code": "123456", "token": "old-token"}, - ): - response = ForgotPasswordCheckApi().post() - - assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} - mock_is_rate_limited.assert_called_once_with("user@example.com") - mock_get_data.assert_called_once_with("old-token") - mock_revoke.assert_called_once_with("old-token") - mock_generate.assert_called_once_with( - "user@example.com", - code="123456", - additional_data={"phase": "reset"}, - ) - mock_reset_limit.assert_called_once_with("user@example.com") - - -@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") -@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") -def test_reset_password_success( - mock_get_data: MagicMock, - mock_revoke_token: MagicMock, - mock_session: MagicMock, - mock_token_bytes: MagicMock, - mock_hash_password: MagicMock, - app: Flask, -): - """POST /forgot-password/resets updates the stored password when token is valid.""" - - mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"} - account = MagicMock() - session_ctx = MagicMock() - mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = account - - with app.test_request_context( - "/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() - - assert response == {"result": "success"} - mock_get_data.assert_called_once_with("reset-token") - mock_revoke_token.assert_called_once_with("reset-token") - mock_token_bytes.assert_called_once_with(16) - mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") - expected_password = base64.b64encode(b"hashed-value").decode() - assert account.password == expected_password - expected_salt = base64.b64encode(b"0123456789abcdef").decode() - assert account.password_salt == expected_salt - session_ctx.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_forgot_password.py b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py new file mode 100644 index 0000000000..3d7c319947 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py @@ -0,0 +1,226 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", dify_settings), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") + @patch("controllers.web.forgot_password.Session") + def test_should_normalize_email_before_sending( + self, + mock_session_cls, + mock_extract_ip, + mock_rate_limit, + mock_get_account, + mock_send_mail, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_mail.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") + mock_extract_ip.assert_called_once() + mock_rate_limit.assert_called_once_with("127.0.0.1") + + +class TestForgotPasswordCheckApi: + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_normalize_email_for_validity_checks( + self, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "User@Example.com", + code="1234", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("user@example.com") + + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_preserve_token_email_case( + self, + mock_is_rate_limit, + mock_get_data, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} + mock_generate_token.return_value = (None, "fresh-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"} + mock_generate_token.assert_called_once_with( + "MixedCase@Example.com", + code="5678", + additional_data={"phase": "reset"}, + ) + mock_revoke_token.assert_called_once_with("token-upper") + mock_reset_rate.assert_called_once_with("mixedcase@example.com") + + +class TestForgotPasswordResetApi: + @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + def test_should_fetch_account_with_fallback( + self, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_get_account, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + mock_revoke_token.assert_called_once_with("token-123") + + @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") + @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + def test_should_update_password_and_commit( + self, + mock_get_account, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_token_bytes, + mock_hash_password, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} + account = MagicMock() + mock_get_account.return_value = account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("reset-token") + mock_revoke_token.assert_called_once_with("reset-token") + mock_token_bytes.assert_called_once_with(16) + mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") + expected_password = base64.b64encode(b"hashed-value").decode() + assert account.password == expected_password + expected_salt = base64.b64encode(b"0123456789abcdef").decode() + assert account.password_salt == expected_salt + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py new file mode 100644 index 0000000000..e62993e8d5 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -0,0 +1,91 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi + + +def encode_code(code: str) -> str: + return base64.b64encode(code.encode("utf-8")).decode() + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True) + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", console_dify), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + patch("controllers.web.login.dify_config", web_dify), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestEmailCodeLoginSendEmailApi: + @patch("controllers.web.login.WebAppAuthService.send_email_code_login_email") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + def test_should_fetch_account_with_original_email( + self, + mock_get_user, + mock_send_email, + app, + ): + mock_account = MagicMock() + mock_get_user.return_value = mock_account + mock_send_email.return_value = "token-123" + + with app.test_request_context( + "/web/email-code-login", + method="POST", + json={"email": "User@Example.com", "language": "en-US"}, + ): + response = EmailCodeLoginSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_user.assert_called_once_with("User@Example.com") + mock_send_email.assert_called_once_with(account=mock_account, language="en-US") + + +class TestEmailCodeLoginApi: + @patch("controllers.web.login.AccountService.reset_login_error_rate_limit") + @patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data") + def test_should_normalize_email_before_validating( + self, + mock_get_token_data, + mock_revoke_token, + mock_get_user, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_user.return_value = MagicMock() + + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + response = EmailCodeLoginApi().post() + + assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}} + mock_get_user.assert_called_once_with("User@Example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_login.assert_called_once() + mock_reset_login_rate.assert_called_once_with("user@example.com") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index e35ba74c56..8ae20f35d8 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from configs import dify_config -from models.account import Account +from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import ( AccountAlreadyInTenantError, @@ -1147,9 +1147,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - with patch("services.account_service.Session") as mock_session_class: + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None # Mock RegisterService.register mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( @@ -1182,9 +1186,59 @@ class TestRegisterService: email="newuser@example.com", name="newuser", language="en-US", - status="pending", + status=AccountStatus.PENDING, is_setup=True, ) + mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + + def test_invite_new_member_normalizes_new_account_email( + self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies + ): + """Ensure inviting with mixed-case email normalizes before registering.""" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + mixed_email = "Invitee@Example.com" + + mock_session = MagicMock() + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None + + mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending" + ) + with patch("services.account_service.RegisterService.register") as mock_register: + mock_register.return_value = mock_new_account + with ( + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant, + patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token, + ): + mock_generate_token.return_value = "invite-token-abc" + + RegisterService.invite_new_member( + tenant=mock_tenant, + email=mixed_email, + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + mock_register.assert_called_once_with( + email="invitee@example.com", + name="invitee", + language="en-US", + status=AccountStatus.PENDING, + is_setup=True, + ) + mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account) @@ -1207,9 +1261,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - with patch("services.account_service.Session") as mock_session_class: + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = mock_existing_account # Mock the db.session.query for TenantAccountJoin mock_db_query = MagicMock() @@ -1238,6 +1296,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() + mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" @@ -1251,7 +1310,6 @@ class TestRegisterService: # Mock database queries query_results = { - ("Account", "email", "existing@example.com"): mock_existing_account, ( "TenantAccountJoin", "tenant_id", @@ -1261,7 +1319,11 @@ class TestRegisterService: ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) # Mock TenantService methods - with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission: + with ( + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + ): + mock_lookup.return_value = mock_existing_account # Execute test and verify exception self._assert_exception_raised( AccountAlreadyInTenantError, @@ -1272,6 +1334,7 @@ class TestRegisterService: role="normal", inviter=mock_inviter, ) + mock_lookup.assert_called_once() def test_invite_new_member_no_inviter(self): """Test inviting a member without providing an inviter.""" @@ -1497,6 +1560,30 @@ class TestRegisterService: # Verify results assert result is None + def test_get_invitation_with_case_fallback_returns_initial_match(self): + """Fallback helper should return the initial invitation when present.""" + invitation = {"workspace_id": "tenant-456"} + with patch( + "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation + ) as mock_get: + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + + def test_get_invitation_with_case_fallback_retries_with_lowercase(self): + """Fallback helper should retry with lowercase email when needed.""" + invitation = {"workspace_id": "tenant-456"} + with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: + mock_get.side_effect = [None, invitation] + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + assert mock_get.call_args_list == [ + (("tenant-456", "User@Test.com", "token-123"),), + (("tenant-456", "user@test.com", "token-123"),), + ] + # ==================== Helper Method Tests ==================== def test_get_invitation_token_key(self): From 1fbdf6b465b8f2df07fb9db2cc09adc42be79b6b Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:59:49 +0800 Subject: [PATCH 16/29] refactor(web): setup status caching (#30798) --- web/__tests__/embedded-user-id-store.test.tsx | 1 + web/app/components/app-initializer.tsx | 13 +- .../access-control.spec.tsx | 1 - web/app/install/installForm.spec.tsx | 8 + web/context/global-public-context.tsx | 53 ++++--- web/context/web-app-context.tsx | 4 +- web/hooks/use-document-title.spec.ts | 24 ++- web/hooks/use-document-title.ts | 4 +- web/utils/setup-status.spec.ts | 139 ++++++++++++++++++ web/utils/setup-status.ts | 21 +++ 10 files changed, 229 insertions(+), 39 deletions(-) create mode 100644 web/utils/setup-status.spec.ts create mode 100644 web/utils/setup-status.ts diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 276b22bcd7..901218e76b 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => { ) return { useGlobalPublicStore, + useIsSystemFeaturesPending: () => false, } }) diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e30646eb3f..3410ecbe9a 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -9,8 +9,8 @@ import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' -import { fetchSetupStatus } from '@/service/common' import { sendGAEvent } from '@/utils/gtag' +import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' import { trackEvent } from './base/amplitude' @@ -33,15 +33,8 @@ export const AppInitializer = ({ const isSetupFinished = useCallback(async () => { try { - if (localStorage.getItem('setup_status') === 'finished') - return true - const setUpStatus = await fetchSetupStatus() - if (setUpStatus.step !== 'finished') { - localStorage.removeItem('setup_status') - return false - } - localStorage.setItem('setup_status', 'finished') - return true + const setUpStatus = await fetchSetupStatusWithCache() + return setUpStatus.step === 'finished' } catch (error) { console.error(error) diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index dd9acd3479..0624cb316b 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -125,7 +125,6 @@ const resetAccessControlStore = () => { const resetGlobalStore = () => { useGlobalPublicStore.setState({ systemFeatures: defaultSystemFeatures, - isGlobalPending: false, }) } diff --git a/web/app/install/installForm.spec.tsx b/web/app/install/installForm.spec.tsx index 74602f916a..5efd5cebb6 100644 --- a/web/app/install/installForm.spec.tsx +++ b/web/app/install/installForm.spec.tsx @@ -19,6 +19,14 @@ vi.mock('@/service/common', () => ({ getSystemFeatures: vi.fn(), })) +vi.mock('@/context/global-public-context', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useIsSystemFeaturesPending: () => false, + } +}) + const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) const mockFetchInitValidateStatus = vi.mocked(fetchInitValidateStatus) const mockSetup = vi.mocked(setup) diff --git a/web/context/global-public-context.tsx b/web/context/global-public-context.tsx index c2742bb7a9..9b2b0834e2 100644 --- a/web/context/global-public-context.tsx +++ b/web/context/global-public-context.tsx @@ -2,42 +2,61 @@ import type { FC, PropsWithChildren } from 'react' import type { SystemFeatures } from '@/types/feature' import { useQuery } from '@tanstack/react-query' -import { useEffect } from 'react' import { create } from 'zustand' import Loading from '@/app/components/base/loading' import { getSystemFeatures } from '@/service/common' import { defaultSystemFeatures } from '@/types/feature' +import { fetchSetupStatusWithCache } from '@/utils/setup-status' type GlobalPublicStore = { - isGlobalPending: boolean - setIsGlobalPending: (isPending: boolean) => void systemFeatures: SystemFeatures setSystemFeatures: (systemFeatures: SystemFeatures) => void } export const useGlobalPublicStore = create(set => ({ - isGlobalPending: true, - setIsGlobalPending: (isPending: boolean) => set(() => ({ isGlobalPending: isPending })), systemFeatures: defaultSystemFeatures, setSystemFeatures: (systemFeatures: SystemFeatures) => set(() => ({ systemFeatures })), })) +const systemFeaturesQueryKey = ['systemFeatures'] as const +const setupStatusQueryKey = ['setupStatus'] as const + +async function fetchSystemFeatures() { + const data = await getSystemFeatures() + const { setSystemFeatures } = useGlobalPublicStore.getState() + setSystemFeatures({ ...defaultSystemFeatures, ...data }) + return data +} + +export function useSystemFeaturesQuery() { + return useQuery({ + queryKey: systemFeaturesQueryKey, + queryFn: fetchSystemFeatures, + }) +} + +export function useIsSystemFeaturesPending() { + const { isPending } = useSystemFeaturesQuery() + return isPending +} + +export function useSetupStatusQuery() { + return useQuery({ + queryKey: setupStatusQueryKey, + queryFn: fetchSetupStatusWithCache, + staleTime: Infinity, + }) +} + const GlobalPublicStoreProvider: FC = ({ children, }) => { - const { isPending, data } = useQuery({ - queryKey: ['systemFeatures'], - queryFn: getSystemFeatures, - }) - const { setSystemFeatures, setIsGlobalPending: setIsPending } = useGlobalPublicStore() - useEffect(() => { - if (data) - setSystemFeatures({ ...defaultSystemFeatures, ...data }) - }, [data, setSystemFeatures]) + // Fetch systemFeatures and setupStatus in parallel to reduce waterfall. + // setupStatus is prefetched here and cached in localStorage for AppInitializer. + const { isPending } = useSystemFeaturesQuery() - useEffect(() => { - setIsPending(isPending) - }, [isPending, setIsPending]) + // Prefetch setupStatus for AppInitializer (result not needed here) + useSetupStatusQuery() if (isPending) return
diff --git a/web/context/web-app-context.tsx b/web/context/web-app-context.tsx index e6680c95a5..c5488a565c 100644 --- a/web/context/web-app-context.tsx +++ b/web/context/web-app-context.tsx @@ -10,7 +10,7 @@ import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/ import Loading from '@/app/components/base/loading' import { AccessMode } from '@/models/access-control' import { useGetWebAppAccessModeByCode } from '@/service/use-share' -import { useGlobalPublicStore } from './global-public-context' +import { useIsSystemFeaturesPending } from './global-public-context' type WebAppStore = { shareCode: string | null @@ -65,7 +65,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => { } const WebAppStoreProvider: FC = ({ children }) => { - const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending) + const isGlobalPending = useIsSystemFeaturesPending() const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode) const updateShareCode = useWebAppStore(state => state.updateShareCode) const updateEmbeddedUserId = useWebAppStore(state => state.updateEmbeddedUserId) diff --git a/web/hooks/use-document-title.spec.ts b/web/hooks/use-document-title.spec.ts index 3909978591..efa72cac5c 100644 --- a/web/hooks/use-document-title.spec.ts +++ b/web/hooks/use-document-title.spec.ts @@ -1,5 +1,5 @@ import { act, renderHook } from '@testing-library/react' -import { useGlobalPublicStore } from '@/context/global-public-context' +import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context' /** * Test suite for useDocumentTitle hook * @@ -15,6 +15,14 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { defaultSystemFeatures } from '@/types/feature' import useDocumentTitle from './use-document-title' +vi.mock('@/context/global-public-context', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useIsSystemFeaturesPending: vi.fn(() => false), + } +}) + vi.mock('@/service/common', () => ({ getSystemFeatures: vi.fn(() => ({ ...defaultSystemFeatures })), })) @@ -24,10 +32,12 @@ vi.mock('@/service/common', () => ({ * Title should remain empty to prevent flicker */ describe('title should be empty if systemFeatures is pending', () => { - act(() => { - useGlobalPublicStore.setState({ - systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } }, - isGlobalPending: true, + beforeEach(() => { + vi.mocked(useIsSystemFeaturesPending).mockReturnValue(true) + act(() => { + useGlobalPublicStore.setState({ + systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } }, + }) }) }) /** @@ -52,9 +62,9 @@ describe('title should be empty if systemFeatures is pending', () => { */ describe('use default branding', () => { beforeEach(() => { + vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false) act(() => { useGlobalPublicStore.setState({ - isGlobalPending: false, systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } }, }) }) @@ -84,9 +94,9 @@ describe('use default branding', () => { */ describe('use specific branding', () => { beforeEach(() => { + vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false) act(() => { useGlobalPublicStore.setState({ - isGlobalPending: false, systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: true, application_title: 'Test' } }, }) }) diff --git a/web/hooks/use-document-title.ts b/web/hooks/use-document-title.ts index bb69aeb20f..37b31a7dea 100644 --- a/web/hooks/use-document-title.ts +++ b/web/hooks/use-document-title.ts @@ -1,11 +1,11 @@ 'use client' import { useFavicon, useTitle } from 'ahooks' import { useEffect } from 'react' -import { useGlobalPublicStore } from '@/context/global-public-context' +import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context' import { basePath } from '@/utils/var' export default function useDocumentTitle(title: string) { - const isPending = useGlobalPublicStore(s => s.isGlobalPending) + const isPending = useIsSystemFeaturesPending() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const prefix = title ? `${title} - ` : '' let titleStr = '' diff --git a/web/utils/setup-status.spec.ts b/web/utils/setup-status.spec.ts new file mode 100644 index 0000000000..be96b43eba --- /dev/null +++ b/web/utils/setup-status.spec.ts @@ -0,0 +1,139 @@ +import type { SetupStatusResponse } from '@/models/common' + +import { fetchSetupStatus } from '@/service/common' + +import { fetchSetupStatusWithCache } from './setup-status' + +vi.mock('@/service/common', () => ({ + fetchSetupStatus: vi.fn(), +})) + +const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) + +describe('setup-status utilities', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + }) + + describe('fetchSetupStatusWithCache', () => { + describe('when cache exists', () => { + it('should return cached finished status without API call', async () => { + localStorage.setItem('setup_status', 'finished') + + const result = await fetchSetupStatusWithCache() + + expect(result).toEqual({ step: 'finished' }) + expect(mockFetchSetupStatus).not.toHaveBeenCalled() + }) + + it('should not modify localStorage when returning cached value', async () => { + localStorage.setItem('setup_status', 'finished') + + await fetchSetupStatusWithCache() + + expect(localStorage.getItem('setup_status')).toBe('finished') + }) + }) + + describe('when cache does not exist', () => { + it('should call API and cache finished status', async () => { + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBe('finished') + }) + + it('should call API and remove cache when not finished', async () => { + const apiResponse: SetupStatusResponse = { step: 'not_started' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBeNull() + }) + + it('should clear stale cache when API returns not_started', async () => { + localStorage.setItem('setup_status', 'some_invalid_value') + const apiResponse: SetupStatusResponse = { step: 'not_started' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBeNull() + }) + }) + + describe('cache edge cases', () => { + it('should call API when cache value is empty string', async () => { + localStorage.setItem('setup_status', '') + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + }) + + it('should call API when cache value is not "finished"', async () => { + localStorage.setItem('setup_status', 'not_started') + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + }) + + it('should call API when localStorage key does not exist', async () => { + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + }) + }) + + describe('API response handling', () => { + it('should preserve setup_at from API response', async () => { + const setupDate = new Date('2024-01-01') + const apiResponse: SetupStatusResponse = { + step: 'finished', + setup_at: setupDate, + } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(result).toEqual(apiResponse) + expect(result.setup_at).toEqual(setupDate) + }) + + it('should propagate API errors', async () => { + const apiError = new Error('Network error') + mockFetchSetupStatus.mockRejectedValue(apiError) + + await expect(fetchSetupStatusWithCache()).rejects.toThrow('Network error') + }) + + it('should not update cache when API call fails', async () => { + mockFetchSetupStatus.mockRejectedValue(new Error('API error')) + + await expect(fetchSetupStatusWithCache()).rejects.toThrow() + + expect(localStorage.getItem('setup_status')).toBeNull() + }) + }) + }) +}) diff --git a/web/utils/setup-status.ts b/web/utils/setup-status.ts new file mode 100644 index 0000000000..7a2810bffd --- /dev/null +++ b/web/utils/setup-status.ts @@ -0,0 +1,21 @@ +import type { SetupStatusResponse } from '@/models/common' +import { fetchSetupStatus } from '@/service/common' + +const SETUP_STATUS_KEY = 'setup_status' + +const isSetupStatusCached = (): boolean => + localStorage.getItem(SETUP_STATUS_KEY) === 'finished' + +export const fetchSetupStatusWithCache = async (): Promise => { + if (isSetupStatusCached()) + return { step: 'finished' } + + const status = await fetchSetupStatus() + + if (status.step === 'finished') + localStorage.setItem(SETUP_STATUS_KEY, 'finished') + else + localStorage.removeItem(SETUP_STATUS_KEY) + + return status +} From a22cc5bc5e2e67fdb84995d840cdc2c4d6441d81 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 13 Jan 2026 17:49:13 +0800 Subject: [PATCH 17/29] chore: Bump Dify version to 1.11.3 (#30903) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- docker/docker-compose-template.yaml | 8 ++++---- docker/docker-compose.yaml | 8 ++++---- web/package.json | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 7d2d68bc8d..28bd591d17 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.2" +version = "1.11.3" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/uv.lock b/api/uv.lock index a999c4ee18..444c7f2f5a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.2" +version = "1.11.3" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 709aff23df..aada39569e 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 041f60aaa2..fcb07dda36 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -704,7 +704,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -746,7 +746,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -785,7 +785,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -815,7 +815,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index 4019e49cd9..44cc9196f4 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.11.2", + "version": "1.11.3", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { From fe07c810ba7fe96caf8e4ae6bf78ec4884fdf68a Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 13 Jan 2026 21:15:21 +0800 Subject: [PATCH 18/29] fix: fix instance is not bind to session (#30913) --- api/core/tools/workflow_as_tool/tool.py | 55 ++++++++++--------- .../core/tools/workflow_as_tool/test_tool.py | 40 ++++++++++++-- 2 files changed, 65 insertions(+), 30 deletions(-) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 81a1d54199..389db8a972 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,8 +7,8 @@ from typing import Any, cast from flask import has_request_context from sqlalchemy import select -from sqlalchemy.orm import Session +from core.db.session_factory import session_factory from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool @@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user from models import Account, Tenant @@ -230,30 +229,32 @@ class WorkflowTool(Tool): """ Resolve user from database (worker/Celery context). """ + with session_factory.create_session() as session: + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = session.scalar(tenant_stmt) + if not tenant: + return None + + user_stmt = select(Account).where(Account.id == user_id) + user = session.scalar(user_stmt) + if user: + user.current_tenant = tenant + session.expunge(user) + return user + + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = session.scalar(end_user_stmt) + if end_user: + session.expunge(end_user) + return end_user - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) - tenant = db.session.scalar(tenant_stmt) - if not tenant: return None - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if user: - user.current_tenant = tenant - return user - - end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) - end_user = db.session.scalar(end_user_stmt) - if end_user: - return end_user - - return None - def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version """ - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): if not version: stmt = ( select(Workflow) @@ -265,22 +266,24 @@ class WorkflowTool(Tool): stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) workflow = session.scalar(stmt) - if not workflow: - raise ValueError("workflow not found or not published") + if not workflow: + raise ValueError("workflow not found or not published") - return workflow + session.expunge(workflow) + return workflow def _get_app(self, app_id: str) -> App: """ get the app by app id """ stmt = select(App).where(App.id == app_id) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): app = session.scalar(stmt) - if not app: - raise ValueError("app not found") + if not app: + raise ValueError("app not found") - return app + session.expunge(app) + return app def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 5d180c7cbc..cd45292488 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M def scalar(self, _stmt): return self.results.pop(0) + # SQLAlchemy Session APIs used by code under test + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + # support `with session_factory.create_session() as session:` + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") - db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + # Monkeypatch session factory to return our stub session + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([tenant, None, end_user]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), @@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt def scalar(self, _stmt): return self.results.pop(0) - db_stub = SimpleNamespace(session=StubSession([None])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + # Monkeypatch session factory to return our stub session with no tenant + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([None]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), From a129e684cced304b8ffbf2d2af4d43e30e30a2a8 Mon Sep 17 00:00:00 2001 From: Yunlu Wen Date: Tue, 13 Jan 2026 22:37:39 +0800 Subject: [PATCH 19/29] feat: inject traceparent in enterprise api (#30895) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/enterprise/base.py | 14 +++++ .../services/enterprise/__init__.py | 0 .../test_traceparent_propagation.py | 59 +++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 api/tests/unit_tests/services/enterprise/__init__.py create mode 100644 api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index bdc960aa2d..e3832475aa 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -1,9 +1,14 @@ +import logging import os from collections.abc import Mapping from typing import Any import httpx +from core.helper.trace_id_helper import generate_traceparent_header + +logger = logging.getLogger(__name__) + class BaseRequest: proxies: Mapping[str, str] | None = { @@ -38,6 +43,15 @@ class BaseRequest: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" mounts = cls._build_mounts() + + try: + # ensure traceparent even when OTEL is disabled + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + except Exception: + logger.debug("Failed to generate traceparent header", exc_info=True) + with httpx.Client(mounts=mounts) as client: response = client.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/tests/unit_tests/services/enterprise/__init__.py b/api/tests/unit_tests/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py new file mode 100644 index 0000000000..87c03f13a3 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py @@ -0,0 +1,59 @@ +"""Unit tests for traceparent header propagation in EnterpriseRequest. + +This test module verifies that the W3C traceparent header is properly +generated and included in HTTP requests made by EnterpriseRequest. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.enterprise.base import EnterpriseRequest + + +class TestTraceparentPropagation: + """Unit tests for traceparent header propagation.""" + + @pytest.fixture + def mock_enterprise_config(self): + """Mock EnterpriseRequest configuration.""" + with ( + patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"), + patch.object(EnterpriseRequest, "secret_key", "test-secret-key"), + patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"), + ): + yield + + @pytest.fixture + def mock_httpx_client(self): + """Mock httpx.Client for testing.""" + with patch("services.enterprise.base.httpx.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value.__enter__.return_value = mock_client + mock_client_class.return_value.__exit__.return_value = None + + # Setup default response + mock_response = MagicMock() + mock_response.json.return_value = {"result": "success"} + mock_client.request.return_value = mock_response + + yield mock_client + + def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client): + """Test that traceparent header is included when successfully generated.""" + # Arrange + expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" + + with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent): + # Act + EnterpriseRequest.send_request("GET", "/test") + + # Assert + mock_httpx_client.request.assert_called_once() + call_args = mock_httpx_client.request.call_args + headers = call_args[1]["headers"] + + assert "traceparent" in headers + assert headers["traceparent"] == expected_traceparent + assert headers["Content-Type"] == "application/json" + assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key" From 91da784f84f544c840cd645addff2b988e4d648c Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:38:28 +0800 Subject: [PATCH 20/29] refactor: init orpc contract (#30885) Co-authored-by: yyh --- .../access-control.spec.tsx | 7 - .../chat/chat-with-history/hooks.spec.tsx | 8 +- .../plans/cloud-plan-item/index.spec.tsx | 18 ++- .../pricing/plans/cloud-plan-item/index.tsx | 5 +- .../hooks/use-marketplace-all-plugins.ts | 10 +- .../model-provider-page/hooks.ts | 10 +- .../components/plugins/marketplace/hooks.ts | 24 +-- .../plugins/marketplace/hydration-server.tsx | 4 +- .../plugins/marketplace/index.spec.tsx | 141 ++++++++++-------- .../components/plugins/marketplace/query.ts | 27 ++-- .../components/plugins/marketplace/state.ts | 4 +- .../components/plugins/marketplace/types.ts | 6 +- .../components/plugins/marketplace/utils.ts | 87 ++++------- web/app/install/installForm.spec.tsx | 1 - web/context/global-public-context.tsx | 4 +- web/contract/base.ts | 3 + web/contract/console.ts | 34 +++++ web/contract/marketplace.ts | 56 +++++++ web/contract/router.ts | 19 +++ web/hooks/use-document-title.spec.ts | 4 - web/package.json | 4 + web/pnpm-lock.yaml | 137 ++++++++++++++++- web/service/base.ts | 5 + web/service/billing.ts | 16 +- web/service/client.ts | 61 ++++++++ web/service/common.ts | 5 - web/service/fetch.ts | 10 +- web/service/use-billing.ts | 15 +- web/service/use-plugins.ts | 24 +-- 29 files changed, 520 insertions(+), 229 deletions(-) create mode 100644 web/contract/base.ts create mode 100644 web/contract/console.ts create mode 100644 web/contract/marketplace.ts create mode 100644 web/contract/router.ts create mode 100644 web/service/client.ts diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index 0624cb316b..b73ed5c266 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({ }), })) -vi.mock('@/service/common', () => ({ - fetchCurrentWorkspace: vi.fn(), - fetchLangGeniusVersion: vi.fn(), - fetchUserProfile: vi.fn(), - getSystemFeatures: vi.fn(), -})) - vi.mock('@/service/access-control', () => ({ useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index a6d51d8643..f6a8f25cbb 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -170,8 +170,12 @@ describe('useChatWithHistory', () => { await waitFor(() => { expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1') }) - expect(result.current.pinnedConversationList).toEqual(pinnedData.data) - expect(result.current.conversationList).toEqual(listData.data) + await waitFor(() => { + expect(result.current.pinnedConversationList).toEqual(pinnedData.data) + }) + await waitFor(() => { + expect(result.current.conversationList).toEqual(listData.data) + }) }) }) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx index 4473ef98fa..680243a474 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx @@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { useAppContext } from '@/context/app-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' -import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' +import { fetchSubscriptionUrls } from '@/service/billing' +import { consoleClient } from '@/service/client' import Toast from '../../../../base/toast' import { ALL_PLANS } from '../../../config' import { Plan } from '../../../type' @@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({ })) vi.mock('@/service/billing', () => ({ - fetchBillingUrl: vi.fn(), fetchSubscriptionUrls: vi.fn(), })) +vi.mock('@/service/client', () => ({ + consoleClient: { + billingUrl: vi.fn(), + }, +})) + vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: vi.fn(), })) @@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({ const mockUseAppContext = useAppContext as Mock const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock -const mockFetchBillingUrl = fetchBillingUrl as Mock +const mockBillingUrl = consoleClient.billingUrl as Mock const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock const mockToastNotify = Toast.notify as Mock @@ -69,7 +75,7 @@ beforeEach(() => { vi.clearAllMocks() mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true }) mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open())) - mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' }) + mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' }) mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' }) assignedHref = '' }) @@ -143,7 +149,7 @@ describe('CloudPlanItem', () => { type: 'error', message: 'billing.buyPermissionDeniedTip', })) - expect(mockFetchBillingUrl).not.toHaveBeenCalled() + expect(mockBillingUrl).not.toHaveBeenCalled() }) it('should open billing portal when upgrading current paid plan', async () => { @@ -162,7 +168,7 @@ describe('CloudPlanItem', () => { fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' })) await waitFor(() => { - expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1) + expect(mockBillingUrl).toHaveBeenCalledTimes(1) }) expect(openWindow).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index b694dc57e2..d9c4d3f75b 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -6,7 +6,8 @@ import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useAppContext } from '@/context/app-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' -import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' +import { fetchSubscriptionUrls } from '@/service/billing' +import { consoleClient } from '@/service/client' import Toast from '../../../../base/toast' import { ALL_PLANS } from '../../../config' import { Plan } from '../../../type' @@ -76,7 +77,7 @@ const CloudPlanItem: FC = ({ try { if (isCurrentPaidPlan) { await openAsyncWindow(async () => { - const res = await fetchBillingUrl() + const res = await consoleClient.billingUrl() if (res.url) return res.url throw new Error('Failed to open billing page') diff --git a/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts b/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts index 0c2154210c..90ef6e78a4 100644 --- a/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts +++ b/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts @@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) = category: PluginCategoryEnum.datasource, exclude, type: 'plugin', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } else { @@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) = query: '', category: PluginCategoryEnum.datasource, type: 'plugin', - pageSize: 1000, + page_size: 1000, exclude, - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 0e35f0fb31..6aba41d4e4 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: category: PluginCategoryEnum.model, exclude, type: 'plugin', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } else { @@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: query: '', category: PluginCategoryEnum.model, type: 'plugin', - pageSize: 1000, + page_size: 1000, exclude, - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) diff --git a/web/app/components/plugins/marketplace/hooks.ts b/web/app/components/plugins/marketplace/hooks.ts index b1e4f50767..60ba0e0bee 100644 --- a/web/app/components/plugins/marketplace/hooks.ts +++ b/web/app/components/plugins/marketplace/hooks.ts @@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => { const [queryParams, setQueryParams] = useState() const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => { - const pageSize = pluginsSearchParams.pageSize || 40 + const page_size = pluginsSearchParams.page_size || 40 return { ...pluginsSearchParams, - pageSize, + page_size, } }, []) @@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => { plugins: [] as Plugin[], total: 0, page: 1, - pageSize: 40, + page_size: 40, } } const params = normalizeParams(queryParams) const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, - pageSize, + page_size, } = params const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' @@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => { const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page: pageParam, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, exclude, @@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => { plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), total: res.data.total, page: pageParam, - pageSize, + page_size, } } catch { @@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => { plugins: [], total: 0, page: pageParam, - pageSize, + page_size, } } }, getNextPageParam: (lastPage) => { const nextPage = lastPage.page + 1 - const loaded = lastPage.page * lastPage.pageSize + const loaded = lastPage.page * lastPage.page_size return loaded < (lastPage.total || 0) ? nextPage : undefined }, initialPageParam: 1, diff --git a/web/app/components/plugins/marketplace/hydration-server.tsx b/web/app/components/plugins/marketplace/hydration-server.tsx index 0aa544cff1..b01f4dd463 100644 --- a/web/app/components/plugins/marketplace/hydration-server.tsx +++ b/web/app/components/plugins/marketplace/hydration-server.tsx @@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs' import { dehydrate, HydrationBoundary } from '@tanstack/react-query' import { createLoader } from 'nuqs/server' import { getQueryClientServer } from '@/context/query-client-server' +import { marketplaceQuery } from '@/service/client' import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants' -import { marketplaceKeys } from './query' import { marketplaceSearchParamsParsers } from './search-params' import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils' @@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise) { const queryClient = getQueryClientServer() await queryClient.prefetchQuery({ - queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)), + queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }), queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)), }) return dehydrate(queryClient) diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 1a3cd15b6b..dc2513ac05 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({ // Mock tanstack query const mockFetchNextPage = vi.fn() const mockHasNextPage = false -let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined +let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise) | null = null let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise) | null = null -let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null +let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null vi.mock('@tanstack/react-query', () => ({ useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise, enabled: boolean }) => { @@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({ }), useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: { queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise - getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined + getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined enabled: boolean }) => { // Capture queryFn and getNextPageParam for later testing @@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({ // Call getNextPageParam to increase coverage if (getNextPageParam) { // Test with more data available - getNextPageParam({ page: 1, pageSize: 40, total: 100 }) + getNextPageParam({ page: 1, page_size: 40, total: 100 }) // Test with no more data - getNextPageParam({ page: 3, pageSize: 40, total: 100 }) + getNextPageParam({ page: 3, page_size: 40, total: 100 }) } return { data: mockInfiniteQueryData, @@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({ // Mock config vi.mock('@/config', () => ({ + API_PREFIX: '/api', APP_VERSION: '1.0.0', IS_MARKETPLACE: false, MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1', @@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => { expect(() => { result.current.queryPlugins({ query: 'test', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', category: 'tool', - pageSize: 20, + page_size: 20, }) }).not.toThrow() }) @@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => { result.current.queryPlugins({ query: 'test', type: 'bundle', - pageSize: 40, + page_size: 40, }) }).not.toThrow() }) @@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => { result.current.queryPlugins({ query: 'test', category: 'all', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) }).not.toThrow() }) @@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => { expect(() => { result.current.queryPlugins({ query: 'test', - pageSize: 100, + page_size: 100, }) }).not.toThrow() }) @@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => { // Set mock data to have pages mockInfiniteQueryData = { pages: [ - { plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 }, + { plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 }, ], } @@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => { it('should expose page and total from infinite query data', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 }, - { plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 }, + { plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 }, + { plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 }, ], } @@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => { it('should return total from first page when query is set and data exists', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [], total: 50, page: 1, pageSize: 40 }, + { plugins: [], total: 50, page: 1, page_size: 40 }, ], } @@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => { type: 'plugin', query: 'search test', category: 'model', - sortBy: 'version_updated_at', - sortOrder: 'ASC', + sort_by: 'version_updated_at', + sort_order: 'ASC', }) expect(result.current).toBeDefined() @@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => { // Test with all possible parameters result.current.queryPlugins({ query: 'comprehensive test', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', category: 'tool', tags: ['tag1', 'tag2'], exclude: ['excluded-plugin'], type: 'plugin', - pageSize: 50, + page_size: 50, }) expect(result.current).toBeDefined() @@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => { result.current.queryPlugins({ query: 'direct test', category: 'tool', - sortBy: 'install_count', - sortOrder: 'DESC', - pageSize: 40, + sort_by: 'install_count', + sort_order: 'DESC', + page_size: 40, }) // Now queryFn should be captured and enabled @@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => { result.current.queryPlugins({ query: 'structure test', - pageSize: 20, + page_size: 20, }) if (capturedInfiniteQueryFn) { @@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // Verify the returned structure expect(response).toHaveProperty('plugins') expect(response).toHaveProperty('total') expect(response).toHaveProperty('page') - expect(response).toHaveProperty('pageSize') + expect(response).toHaveProperty('page_size') } }) }) @@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => { ], total: 5, page: 1, - pageSize: 40, + page_size: 40, }, { plugins: [ @@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => { ], total: 5, page: 2, - pageSize: 40, + page_size: 40, }, ], } @@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => { it('should test hook with pages data for flatMap path', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [], total: 100, page: 1, pageSize: 40 }, - { plugins: [], total: 100, page: 2, pageSize: 40 }, + { plugins: [], total: 100, page: 1, page_size: 40 }, + { plugins: [], total: 100, page: 2, page_size: 40 }, ], } @@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // When error is caught, should return fallback data expect(response.plugins).toEqual([]) @@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => { // Test getNextPageParam function directly if (capturedGetNextPageParam) { // When there are more pages - const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 }) + const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 }) expect(nextPage).toBe(2) // When all data is loaded - const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 }) + const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 }) expect(noMorePages).toBeUndefined() // Edge case: exactly at boundary - const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 }) + const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 }) expect(atBoundary).toBeUndefined() } }) @@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // Catch block should return fallback values expect(response.plugins).toEqual([]) @@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => { plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }], total: 10, page: 1, - pageSize: 40, + page_size: 40, }, ], } @@ -1489,9 +1490,12 @@ describe('Async Utils', () => { { type: 'plugin', org: 'test', name: 'plugin2' }, ] - globalThis.fetch = vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection', { @@ -1514,19 +1518,26 @@ describe('Async Utils', () => { }) it('should pass abort signal when provided', async () => { - const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] - globalThis.fetch = vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const controller = new AbortController() const { getMarketplacePluginsByCollectionId } = await import('./utils') await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) + // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL expect(globalThis.fetch).toHaveBeenCalledWith( - expect.any(String), - expect.objectContaining({ signal: controller.signal }), + expect.any(Request), + expect.any(Object), ) + const call = vi.mocked(globalThis.fetch).mock.calls[0] + const request = call[0] as Request + expect(request.url).toContain('test-collection') }) }) @@ -1535,19 +1546,25 @@ describe('Async Utils', () => { const mockCollections = [ { name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' }, ] - const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] + const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] let callCount = 0 globalThis.fetch = vi.fn().mockImplementation(() => { callCount++ if (callCount === 1) { - return Promise.resolve({ - json: () => Promise.resolve({ data: { collections: mockCollections } }), - }) + return Promise.resolve( + new Response(JSON.stringify({ data: { collections: mockCollections } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) } - return Promise.resolve({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + return Promise.resolve( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') @@ -1571,9 +1588,12 @@ describe('Async Utils', () => { }) it('should append condition and type to URL when provided', async () => { - globalThis.fetch = vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { collections: [] } }), - }) + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { collections: [] } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') await getMarketplaceCollectionsAndPlugins({ @@ -1581,10 +1601,11 @@ describe('Async Utils', () => { type: 'bundle', }) - expect(globalThis.fetch).toHaveBeenCalledWith( - expect.stringContaining('condition=category=tool'), - expect.any(Object), - ) + // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL + expect(globalThis.fetch).toHaveBeenCalled() + const call = vi.mocked(globalThis.fetch).mock.calls[0] + const request = call[0] as Request + expect(request.url).toContain('condition=category%3Dtool') }) }) }) diff --git a/web/app/components/plugins/marketplace/query.ts b/web/app/components/plugins/marketplace/query.ts index c5a1421146..35d99a2bd5 100644 --- a/web/app/components/plugins/marketplace/query.ts +++ b/web/app/components/plugins/marketplace/query.ts @@ -1,22 +1,14 @@ -import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types' +import type { PluginsSearchParams } from './types' +import type { MarketPlaceInputs } from '@/contract/router' import { useInfiniteQuery, useQuery } from '@tanstack/react-query' +import { marketplaceQuery } from '@/service/client' import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils' -// TODO: Avoid manual maintenance of query keys and better service management, -// https://github.com/langgenius/dify/issues/30342 - -export const marketplaceKeys = { - all: ['marketplace'] as const, - collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const, - collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const, - plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const, -} - export function useMarketplaceCollectionsAndPlugins( - collectionsParams: CollectionsAndPluginsSearchParams, + collectionsParams: MarketPlaceInputs['collections']['query'], ) { return useQuery({ - queryKey: marketplaceKeys.collections(collectionsParams), + queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }), queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }), }) } @@ -25,11 +17,16 @@ export function useMarketplacePlugins( queryParams: PluginsSearchParams | undefined, ) { return useInfiniteQuery({ - queryKey: marketplaceKeys.plugins(queryParams), + queryKey: marketplaceQuery.searchAdvanced.queryKey({ + input: { + body: queryParams!, + params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' }, + }, + }), queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal), getNextPageParam: (lastPage) => { const nextPage = lastPage.page + 1 - const loaded = lastPage.page * lastPage.pageSize + const loaded = lastPage.page * lastPage.page_size return loaded < (lastPage.total || 0) ? nextPage : undefined }, initialPageParam: 1, diff --git a/web/app/components/plugins/marketplace/state.ts b/web/app/components/plugins/marketplace/state.ts index 1c1abfc0a1..9c76a21e92 100644 --- a/web/app/components/plugins/marketplace/state.ts +++ b/web/app/components/plugins/marketplace/state.ts @@ -26,8 +26,8 @@ export function useMarketplaceData() { query: searchPluginText, category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType, tags: filterPluginTags, - sortBy: sort.sortBy, - sortOrder: sort.sortOrder, + sort_by: sort.sortBy, + sort_order: sort.sortOrder, type: getMarketplaceListFilterType(activePluginType), } }, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort]) diff --git a/web/app/components/plugins/marketplace/types.ts b/web/app/components/plugins/marketplace/types.ts index 4145f69248..e4e2dbd935 100644 --- a/web/app/components/plugins/marketplace/types.ts +++ b/web/app/components/plugins/marketplace/types.ts @@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = { export type PluginsSearchParams = { query: string page?: number - pageSize?: number - sortBy?: string - sortOrder?: string + page_size?: number + sort_by?: string + sort_order?: string category?: string tags?: string[] exclude?: string[] diff --git a/web/app/components/plugins/marketplace/utils.ts b/web/app/components/plugins/marketplace/utils.ts index eaf299314c..01f3c59284 100644 --- a/web/app/components/plugins/marketplace/utils.ts +++ b/web/app/components/plugins/marketplace/utils.ts @@ -4,14 +4,12 @@ import type { MarketplaceCollection, PluginsSearchParams, } from '@/app/components/plugins/marketplace/types' -import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' +import type { Plugin } from '@/app/components/plugins/types' import { PluginCategoryEnum } from '@/app/components/plugins/types' import { - APP_VERSION, - IS_MARKETPLACE, MARKETPLACE_API_PREFIX, } from '@/config' -import { postMarketplace } from '@/service/base' +import { marketplaceClient } from '@/service/client' import { getMarketplaceUrl } from '@/utils/var' import { PLUGIN_TYPE_SEARCH_MAP } from './constants' @@ -19,10 +17,6 @@ type MarketplaceFetchOptions = { signal?: AbortSignal } -const getMarketplaceHeaders = () => new Headers({ - 'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0', -}) - export const getPluginIconInMarketplace = (plugin: Plugin) => { if (plugin.type === 'bundle') return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon` @@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async ( let plugins: Plugin[] = [] try { - const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins` - const headers = getMarketplaceHeaders() - const marketplaceCollectionPluginsData = await globalThis.fetch( - url, - { - cache: 'no-store', - method: 'POST', - headers, - signal: options?.signal, - body: JSON.stringify({ - category: query?.category, - exclude: query?.exclude, - type: query?.type, - }), + const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({ + params: { + collectionId, }, - ) - const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json() - plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin)) + body: query, + }, { + signal: options?.signal, + }) + plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin)) } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { @@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async ( let marketplaceCollections: MarketplaceCollection[] = [] let marketplaceCollectionPluginsMap: Record = {} try { - let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100` - if (query?.condition) - marketplaceUrl += `&condition=${query.condition}` - if (query?.type) - marketplaceUrl += `&type=${query.type}` - const headers = getMarketplaceHeaders() - const marketplaceCollectionsData = await globalThis.fetch( - marketplaceUrl, - { - headers, - cache: 'no-store', - signal: options?.signal, + const marketplaceCollectionsDataJson = await marketplaceClient.collections({ + query: { + ...query, + page: 1, + page_size: 100, }, - ) - const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json() - marketplaceCollections = marketplaceCollectionsDataJson.data.collections || [] + }, { + signal: options?.signal, + }) + marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || [] await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => { const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options) @@ -143,42 +122,42 @@ export const getMarketplacePlugins = async ( plugins: [] as Plugin[], total: 0, page: 1, - pageSize: 40, + page_size: 40, } } const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, type, - pageSize = 40, + page_size = 40, } = queryParams - const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' try { - const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { + const res = await marketplaceClient.searchAdvanced({ + params: { + kind: type === 'bundle' ? 'bundles' : 'plugins', + }, body: { page: pageParam, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, - type, }, - signal, - }) + }, { signal }) const resPlugins = res.data.bundles || res.data.plugins || [] return { plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), total: res.data.total, page: pageParam, - pageSize, + page_size, } } catch { @@ -186,7 +165,7 @@ export const getMarketplacePlugins = async ( plugins: [], total: 0, page: pageParam, - pageSize, + page_size, } } } diff --git a/web/app/install/installForm.spec.tsx b/web/app/install/installForm.spec.tsx index 5efd5cebb6..17ce35d6a1 100644 --- a/web/app/install/installForm.spec.tsx +++ b/web/app/install/installForm.spec.tsx @@ -16,7 +16,6 @@ vi.mock('@/service/common', () => ({ fetchInitValidateStatus: vi.fn(), setup: vi.fn(), login: vi.fn(), - getSystemFeatures: vi.fn(), })) vi.mock('@/context/global-public-context', async (importOriginal) => { diff --git a/web/context/global-public-context.tsx b/web/context/global-public-context.tsx index 9b2b0834e2..3a570fc7ef 100644 --- a/web/context/global-public-context.tsx +++ b/web/context/global-public-context.tsx @@ -4,7 +4,7 @@ import type { SystemFeatures } from '@/types/feature' import { useQuery } from '@tanstack/react-query' import { create } from 'zustand' import Loading from '@/app/components/base/loading' -import { getSystemFeatures } from '@/service/common' +import { consoleClient } from '@/service/client' import { defaultSystemFeatures } from '@/types/feature' import { fetchSetupStatusWithCache } from '@/utils/setup-status' @@ -22,7 +22,7 @@ const systemFeaturesQueryKey = ['systemFeatures'] as const const setupStatusQueryKey = ['setupStatus'] as const async function fetchSystemFeatures() { - const data = await getSystemFeatures() + const data = await consoleClient.systemFeatures() const { setSystemFeatures } = useGlobalPublicStore.getState() setSystemFeatures({ ...defaultSystemFeatures, ...data }) return data diff --git a/web/contract/base.ts b/web/contract/base.ts new file mode 100644 index 0000000000..764db9d554 --- /dev/null +++ b/web/contract/base.ts @@ -0,0 +1,3 @@ +import { oc } from '@orpc/contract' + +export const base = oc.$route({ inputStructure: 'detailed' }) diff --git a/web/contract/console.ts b/web/contract/console.ts new file mode 100644 index 0000000000..ec929d1357 --- /dev/null +++ b/web/contract/console.ts @@ -0,0 +1,34 @@ +import type { SystemFeatures } from '@/types/feature' +import { type } from '@orpc/contract' +import { base } from './base' + +export const systemFeaturesContract = base + .route({ + path: '/system-features', + method: 'GET', + }) + .input(type()) + .output(type()) + +export const billingUrlContract = base + .route({ + path: '/billing/invoices', + method: 'GET', + }) + .input(type()) + .output(type<{ url: string }>()) + +export const bindPartnerStackContract = base + .route({ + path: '/billing/partners/{partnerKey}/tenants', + method: 'PUT', + }) + .input(type<{ + params: { + partnerKey: string + } + body: { + click_id: string + } + }>()) + .output(type()) diff --git a/web/contract/marketplace.ts b/web/contract/marketplace.ts new file mode 100644 index 0000000000..3573ba5c24 --- /dev/null +++ b/web/contract/marketplace.ts @@ -0,0 +1,56 @@ +import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types' +import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' +import { type } from '@orpc/contract' +import { base } from './base' + +export const collectionsContract = base + .route({ + path: '/collections', + method: 'GET', + }) + .input( + type<{ + query?: CollectionsAndPluginsSearchParams & { page?: number, page_size?: number } + }>(), + ) + .output( + type<{ + data?: { + collections?: MarketplaceCollection[] + } + }>(), + ) + +export const collectionPluginsContract = base + .route({ + path: '/collections/{collectionId}/plugins', + method: 'POST', + }) + .input( + type<{ + params: { + collectionId: string + } + body?: CollectionsAndPluginsSearchParams + }>(), + ) + .output( + type<{ + data?: { + plugins?: Plugin[] + } + }>(), + ) + +export const searchAdvancedContract = base + .route({ + path: '/{kind}/search/advanced', + method: 'POST', + }) + .input(type<{ + params: { + kind: 'plugins' | 'bundles' + } + body: Omit + }>()) + .output(type<{ data: PluginsFromMarketplaceResponse }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts new file mode 100644 index 0000000000..d83cffb7b8 --- /dev/null +++ b/web/contract/router.ts @@ -0,0 +1,19 @@ +import type { InferContractRouterInputs } from '@orpc/contract' +import { billingUrlContract, bindPartnerStackContract, systemFeaturesContract } from './console' +import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace' + +export const marketplaceRouterContract = { + collections: collectionsContract, + collectionPlugins: collectionPluginsContract, + searchAdvanced: searchAdvancedContract, +} + +export type MarketPlaceInputs = InferContractRouterInputs + +export const consoleRouterContract = { + systemFeatures: systemFeaturesContract, + billingUrl: billingUrlContract, + bindPartnerStack: bindPartnerStackContract, +} + +export type ConsoleInputs = InferContractRouterInputs diff --git a/web/hooks/use-document-title.spec.ts b/web/hooks/use-document-title.spec.ts index efa72cac5c..7ce1e693db 100644 --- a/web/hooks/use-document-title.spec.ts +++ b/web/hooks/use-document-title.spec.ts @@ -23,10 +23,6 @@ vi.mock('@/context/global-public-context', async (importOriginal) => { } }) -vi.mock('@/service/common', () => ({ - getSystemFeatures: vi.fn(() => ({ ...defaultSystemFeatures })), -})) - /** * Test behavior when system features are still loading * Title should remain empty to prevent flicker diff --git a/web/package.json b/web/package.json index 44cc9196f4..fab33f7608 100644 --- a/web/package.json +++ b/web/package.json @@ -69,6 +69,10 @@ "@monaco-editor/react": "^4.7.0", "@octokit/core": "^6.1.6", "@octokit/request-error": "^6.1.8", + "@orpc/client": "^1.13.4", + "@orpc/contract": "^1.13.4", + "@orpc/openapi-client": "^1.13.4", + "@orpc/tanstack-query": "^1.13.4", "@remixicon/react": "^4.7.0", "@sentry/react": "^8.55.0", "@svgdotjs/svg.js": "^3.2.5", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 853c366025..c8797e3d65 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -108,6 +108,18 @@ importers: '@octokit/request-error': specifier: ^6.1.8 version: 6.1.8 + '@orpc/client': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/contract': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/openapi-client': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/tanstack-query': + specifier: ^1.13.4 + version: 1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12) '@remixicon/react': specifier: ^4.7.0 version: 4.7.0(react@19.2.3) @@ -2291,6 +2303,38 @@ packages: '@open-draft/until@2.1.0': resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} + '@orpc/client@1.13.4': + resolution: {integrity: sha512-s13GPMeoooJc5Th2EaYT5HMFtWG8S03DUVytYfJv8pIhP87RYKl94w52A36denH6r/B4LaAgBeC9nTAOslK+Og==} + + '@orpc/contract@1.13.4': + resolution: {integrity: sha512-TIxyaF67uOlihCRcasjHZxguZpbqfNK7aMrDLnhoufmQBE4OKvguNzmrOFHgsuM0OXoopX0Nuhun1ccaxKP10A==} + + '@orpc/openapi-client@1.13.4': + resolution: {integrity: sha512-tRUcY4E6sgpS5bY/9nNES/Q/PMyYyPOsI4TuhwLhfgxOb0GFPwYKJ6Kif7KFNOhx4fkN/jTOfE1nuWuIZU1gyg==} + + '@orpc/shared@1.13.4': + resolution: {integrity: sha512-TYt9rLG/BUkNQBeQ6C1tEiHS/Seb8OojHgj9GlvqyjHJhMZx5qjsIyTW6RqLPZJ4U2vgK6x4Her36+tlFCKJug==} + peerDependencies: + '@opentelemetry/api': '>=1.9.0' + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + + '@orpc/standard-server-fetch@1.13.4': + resolution: {integrity: sha512-/zmKwnuxfAXbppJpgr1CMnQX3ptPlYcDzLz1TaVzz9VG/Xg58Ov3YhabS2Oi1utLVhy5t4kaCppUducAvoKN+A==} + + '@orpc/standard-server-peer@1.13.4': + resolution: {integrity: sha512-UfqnTLqevjCKUk4cmImOG8cQUwANpV1dp9e9u2O1ki6BRBsg/zlXFg6G2N6wP0zr9ayIiO1d2qJdH55yl/1BNw==} + + '@orpc/standard-server@1.13.4': + resolution: {integrity: sha512-ZOzgfVp6XUg+wVYw+gqesfRfGPtQbnBIrIiSnFMtZF+6ncmFJeF2Shc4RI2Guqc0Qz25juy8Ogo4tX3YqysOcg==} + + '@orpc/tanstack-query@1.13.4': + resolution: {integrity: sha512-gCL/kh3kf6OUGKfXxSoOZpcX1jNYzxGfo/PkLQKX7ui4xiTbfWw3sCDF30sNS4I7yAOnBwDwJ3N2xzfkTftOBg==} + peerDependencies: + '@orpc/client': 1.13.4 + '@tanstack/query-core': '>=5.80.2' + '@oxc-resolver/binding-android-arm-eabi@11.15.0': resolution: {integrity: sha512-Q+lWuFfq7whNelNJIP1dhXaVz4zO9Tu77GcQHyxDWh3MaCoO2Bisphgzmsh4ZoUe2zIchQh6OvQL99GlWHg9Tw==} cpu: [arm] @@ -6685,6 +6729,9 @@ packages: resolution: {integrity: sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==} engines: {node: '>=12'} + openapi-types@12.1.3: + resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} + opener@1.5.2: resolution: {integrity: sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==} hasBin: true @@ -7081,6 +7128,10 @@ packages: queue-microtask@1.2.3: resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} + radash@12.1.1: + resolution: {integrity: sha512-h36JMxKRqrAxVD8201FrCpyeNuUY9Y5zZwujr20fFO77tpUtGa6EZzfKw/3WaiBX95fq7+MpsuMLNdSnORAwSA==} + engines: {node: '>=14.18.0'} + randombytes@2.1.0: resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} @@ -7826,6 +7877,10 @@ packages: tabbable@6.3.0: resolution: {integrity: sha512-EIHvdY5bPLuWForiR/AN2Bxngzpuwn1is4asboytXtpTgsArc+WmSJKVLlhdh71u7jFcryDqB2A8lQvj78MkyQ==} + tagged-tag@1.0.0: + resolution: {integrity: sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==} + engines: {node: '>=20'} + tailwind-merge@2.6.0: resolution: {integrity: sha512-P+Vu1qXfzediirmHOC3xKGAYeZtPcV9g76X+xg2FD4tYgR71ewMA35Y3sCz3zhiN/dwefRpJX0yBcgwi1fXNQA==} @@ -8027,13 +8082,17 @@ packages: resolution: {integrity: sha512-5zknd7Dss75pMSED270A1RQS3KloqRJA9XbXLe0eCxyw7xXFb3rd+9B0UQ/0E+LQT6lnrLviEolYORlRWamn4w==} engines: {node: '>=16'} + type-fest@5.4.0: + resolution: {integrity: sha512-wfkA6r0tBpVfGiyO+zbf9e10QkRQSlK9F2UvyfnjoCmrvH2bjHyhPzhugSBOuq1dog3P0+FKckqe+Xf6WKVjwg==} + engines: {node: '>=20'} + typescript@5.9.3: resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} engines: {node: '>=14.17'} hasBin: true - ufo@1.6.1: - resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} + ufo@1.6.2: + resolution: {integrity: sha512-heMioaxBcG9+Znsda5Q8sQbWnLJSl98AFDXTO80wELWEzX3hordXsTdxrIfMQoO9IY1MEnoGoPjpoKpMj+Yx0Q==} uglify-js@3.19.3: resolution: {integrity: sha512-v3Xu+yuwBXisp6QYTcH4UbH+xYJXqnq2m/LtQVWKWzYc1iehYnLixoQDN9FH6/j9/oybfd6W9Ghwkl8+UMKTKQ==} @@ -10638,6 +10697,66 @@ snapshots: '@open-draft/until@2.1.0': {} + '@orpc/client@1.13.4': + dependencies: + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + '@orpc/standard-server-fetch': 1.13.4 + '@orpc/standard-server-peer': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/contract@1.13.4': + dependencies: + '@orpc/client': 1.13.4 + '@orpc/shared': 1.13.4 + '@standard-schema/spec': 1.1.0 + openapi-types: 12.1.3 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/openapi-client@1.13.4': + dependencies: + '@orpc/client': 1.13.4 + '@orpc/contract': 1.13.4 + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/shared@1.13.4': + dependencies: + radash: 12.1.1 + type-fest: 5.4.0 + + '@orpc/standard-server-fetch@1.13.4': + dependencies: + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/standard-server-peer@1.13.4': + dependencies: + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/standard-server@1.13.4': + dependencies: + '@orpc/shared': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/tanstack-query@1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12)': + dependencies: + '@orpc/client': 1.13.4 + '@orpc/shared': 1.13.4 + '@tanstack/query-core': 5.90.12 + transitivePeerDependencies: + - '@opentelemetry/api' + '@oxc-resolver/binding-android-arm-eabi@11.15.0': optional: true @@ -15603,7 +15722,7 @@ snapshots: acorn: 8.15.0 pathe: 2.0.3 pkg-types: 1.3.1 - ufo: 1.6.1 + ufo: 1.6.2 monaco-editor@0.55.1: dependencies: @@ -15766,6 +15885,8 @@ snapshots: is-docker: 2.2.1 is-wsl: 2.2.0 + openapi-types@12.1.3: {} + opener@1.5.2: {} optionator@0.9.4: @@ -16181,6 +16302,8 @@ snapshots: queue-microtask@1.2.3: {} + radash@12.1.1: {} + randombytes@2.1.0: dependencies: safe-buffer: 5.2.1 @@ -17098,6 +17221,8 @@ snapshots: tabbable@6.3.0: {} + tagged-tag@1.0.0: {} + tailwind-merge@2.6.0: {} tailwindcss@3.4.18(tsx@4.21.0)(yaml@2.8.2): @@ -17305,9 +17430,13 @@ snapshots: type-fest@4.2.0: optional: true + type-fest@5.4.0: + dependencies: + tagged-tag: 1.0.0 + typescript@5.9.3: {} - ufo@1.6.1: {} + ufo@1.6.2: {} uglify-js@3.19.3: {} diff --git a/web/service/base.ts b/web/service/base.ts index 2ab115f96c..fb32ce6bcf 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -81,6 +81,11 @@ export type IOtherOptions = { needAllResponseContent?: boolean deleteContentType?: boolean silent?: boolean + + /** If true, behaves like standard fetch: no URL prefix, returns raw Response */ + fetchCompat?: boolean + request?: Request + onData?: IOnData // for stream onThought?: IOnThought onFile?: IOnFile diff --git a/web/service/billing.ts b/web/service/billing.ts index f06c4f06c6..075ab71ade 100644 --- a/web/service/billing.ts +++ b/web/service/billing.ts @@ -1,5 +1,5 @@ import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type' -import { get, put } from './base' +import { get } from './base' export const fetchCurrentPlanInfo = () => { return get('/features') @@ -8,17 +8,3 @@ export const fetchCurrentPlanInfo = () => { export const fetchSubscriptionUrls = (plan: string, interval: string) => { return get(`/billing/subscription?plan=${plan}&interval=${interval}`) } - -export const fetchBillingUrl = () => { - return get<{ url: string }>('/billing/invoices') -} - -export const bindPartnerStackInfo = (partnerKey: string, clickId: string) => { - return put(`/billing/partners/${partnerKey}/tenants`, { - body: { - click_id: clickId, - }, - }, { - silent: true, - }) -} diff --git a/web/service/client.ts b/web/service/client.ts new file mode 100644 index 0000000000..c9c92ddd15 --- /dev/null +++ b/web/service/client.ts @@ -0,0 +1,61 @@ +import type { ContractRouterClient } from '@orpc/contract' +import type { JsonifiedClient } from '@orpc/openapi-client' +import { createORPCClient, onError } from '@orpc/client' +import { OpenAPILink } from '@orpc/openapi-client/fetch' +import { createTanstackQueryUtils } from '@orpc/tanstack-query' +import { + API_PREFIX, + APP_VERSION, + IS_MARKETPLACE, + MARKETPLACE_API_PREFIX, +} from '@/config' +import { + consoleRouterContract, + marketplaceRouterContract, +} from '@/contract/router' +import { request } from './base' + +const getMarketplaceHeaders = () => new Headers({ + 'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0', +}) + +const marketplaceLink = new OpenAPILink(marketplaceRouterContract, { + url: MARKETPLACE_API_PREFIX, + headers: () => (getMarketplaceHeaders()), + fetch: (request, init) => { + return globalThis.fetch(request, { + ...init, + cache: 'no-store', + }) + }, + interceptors: [ + onError((error) => { + console.error(error) + }), + ], +}) + +export const marketplaceClient: JsonifiedClient> = createORPCClient(marketplaceLink) +export const marketplaceQuery = createTanstackQueryUtils(marketplaceClient, { path: ['marketplace'] }) + +const consoleLink = new OpenAPILink(consoleRouterContract, { + url: API_PREFIX, + fetch: (input, init) => { + return request( + input.url, + init, + { + fetchCompat: true, + request: input, + }, + ) + }, + interceptors: [ + onError((error) => { + console.error(error) + }), + ], +}) + +export const consoleClient: JsonifiedClient> = createORPCClient(consoleLink) +export const consoleQuery = createTanstackQueryUtils(consoleClient, { path: ['console'] }) diff --git a/web/service/common.ts b/web/service/common.ts index 5fc4850d5f..70211d10d3 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -34,7 +34,6 @@ import type { UserProfileOriginResponse, } from '@/models/common' import type { RETRIEVE_METHOD } from '@/types/app' -import type { SystemFeatures } from '@/types/feature' import { del, get, patch, post, put } from './base' type LoginSuccess = { @@ -307,10 +306,6 @@ export const fetchSupportRetrievalMethods = (url: string): Promise(url) } -export const getSystemFeatures = (): Promise => { - return get('/system-features') -} - export const enableModel = (url: string, body: { model: string, model_type: ModelTypeEnum }): Promise => patch(url, { body }) diff --git a/web/service/fetch.ts b/web/service/fetch.ts index d0af932d73..13be7ae97b 100644 --- a/web/service/fetch.ts +++ b/web/service/fetch.ts @@ -136,6 +136,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: needAllResponseContent, deleteContentType, getAbortController, + fetchCompat = false, + request, } = otherOptions let base: string @@ -181,7 +183,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: }, }) - const res = await client(fetchPathname, { + const res = await client(request || fetchPathname, { ...init, headers, credentials: isMarketplaceAPI @@ -190,8 +192,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: retry: { methods: [], }, - ...(bodyStringify ? { json: body } : { body: body as BodyInit }), - searchParams: params, + ...(bodyStringify && !fetchCompat ? { json: body } : { body: body as BodyInit }), + searchParams: !fetchCompat ? params : undefined, fetch(resource: RequestInfo | URL, options?: RequestInit) { if (resource instanceof Request && options) { const mergedHeaders = new Headers(options.headers || {}) @@ -204,7 +206,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: }, }) - if (needAllResponseContent) + if (needAllResponseContent || fetchCompat) return res as T const contentType = res.headers.get('content-type') if ( diff --git a/web/service/use-billing.ts b/web/service/use-billing.ts index 3dc2b8a994..794b192d5c 100644 --- a/web/service/use-billing.ts +++ b/web/service/use-billing.ts @@ -1,21 +1,22 @@ import { useMutation, useQuery } from '@tanstack/react-query' -import { bindPartnerStackInfo, fetchBillingUrl } from '@/service/billing' - -const NAME_SPACE = 'billing' +import { consoleClient, consoleQuery } from '@/service/client' export const useBindPartnerStackInfo = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'bind-partner-stack'], - mutationFn: (data: { partnerKey: string, clickId: string }) => bindPartnerStackInfo(data.partnerKey, data.clickId), + mutationKey: consoleQuery.bindPartnerStack.mutationKey(), + mutationFn: (data: { partnerKey: string, clickId: string }) => consoleClient.bindPartnerStack({ + params: { partnerKey: data.partnerKey }, + body: { click_id: data.clickId }, + }), }) } export const useBillingUrl = (enabled: boolean) => { return useQuery({ - queryKey: [NAME_SPACE, 'url'], + queryKey: consoleQuery.billingUrl.queryKey(), enabled, queryFn: async () => { - const res = await fetchBillingUrl() + const res = await consoleClient.billingUrl() return res.url }, }) diff --git a/web/service/use-plugins.ts b/web/service/use-plugins.ts index 4e9776df97..5267503a11 100644 --- a/web/service/use-plugins.ts +++ b/web/service/use-plugins.ts @@ -488,23 +488,23 @@ export const useMutationPluginsFromMarketplace = () => { mutationFn: (pluginsSearchParams: PluginsSearchParams) => { const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, page = 1, - pageSize = 40, + page_size = 40, } = pluginsSearchParams const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, exclude, @@ -535,23 +535,23 @@ export const useFetchPluginListOrBundleList = (pluginsSearchParams: PluginsSearc queryFn: () => { const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, page = 1, - pageSize = 40, + page_size = 40, } = pluginsSearchParams const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, exclude, From 206706987d8d0c2972fc231bbb631f03f0f4c26f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 13 Jan 2026 22:39:34 +0800 Subject: [PATCH 21/29] refactor(variables): clarify base vs union type naming (#30634) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/app/apps/advanced_chat/app_runner.py | 10 +-- .../conversation_variable_persist_layer.py | 4 +- api/core/variables/__init__.py | 2 + api/core/variables/segments.py | 2 +- api/core/variables/variables.py | 28 ++++----- .../workflow/conversation_variable_updater.py | 6 +- .../graph_engine/entities/commands.py | 4 +- .../nodes/iteration/iteration_node.py | 10 +-- .../nodes/variable_assigner/v1/node.py | 4 +- .../nodes/variable_assigner/v2/node.py | 8 +-- api/core/workflow/runtime/variable_pool.py | 22 +++---- api/core/workflow/variable_loader.py | 8 +-- api/factories/variable_factory.py | 20 +++--- api/fields/workflow_fields.py | 4 +- api/models/workflow.py | 62 +++++++++---------- api/services/conversation_variable_updater.py | 4 +- api/services/rag_pipeline/rag_pipeline.py | 6 +- .../workflow_draft_variable_service.py | 14 ++--- api/services/workflow_service.py | 16 ++--- .../unit_tests/core/variables/test_segment.py | 5 +- .../core/variables/test_variables.py | 4 +- .../core/workflow/test_variable_pool.py | 6 +- 22 files changed, 124 insertions(+), 125 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d636548f2b..a258144d35 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): system_variables=system_inputs, user_inputs=inputs, environment_variables=self._workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. conversation_variables=conversation_variables, ) @@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): trace_manager=app_generate_entity.trace_manager, ) - def _initialize_conversation_variables(self) -> list[VariableUnion]: + def _initialize_conversation_variables(self) -> list[Variable]: """ Initialize conversation variables for the current conversation. @@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation_variables = [var.to_variable() for var in existing_variables] session.commit() - return cast(list[VariableUnion], conversation_variables) + return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: """ diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index 77cc00bdc9..c070845b73 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,6 +1,6 @@ import logging -from core.variables import Variable +from core.variables import VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType @@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): if selector[0] != CONVERSATION_VARIABLE_NODE_ID: continue variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): logger.warning( "Conversation variable not found in variable pool. selector=%s", selector, diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 7a1cbf9940..7498224923 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -30,6 +30,7 @@ from .variables import ( SecretVariable, StringVariable, Variable, + VariableBase, ) __all__ = [ @@ -62,4 +63,5 @@ __all__ = [ "StringSegment", "StringVariable", "Variable", + "VariableBase", ] diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 406b4e6f93..8330f1fe19 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None: # - All variants in `SegmentUnion` must inherit from the `Segment` class. # - The union must include all non-abstract subclasses of `Segment`, except: # - `SegmentGroup`, which is not added to the variable pool. -# - `Variable` and its subclasses, which are handled by `VariableUnion`. +# - `VariableBase` and its subclasses, which are handled by `Variable`. SegmentUnion: TypeAlias = Annotated[ ( Annotated[NoneSegment, Tag(SegmentType.NONE)] diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 9fd0bbc5b2..a19c53918d 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -27,7 +27,7 @@ from .segments import ( from .types import SegmentType -class Variable(Segment): +class VariableBase(Segment): """ A variable is a segment that has a name. @@ -45,23 +45,23 @@ class Variable(Segment): selector: Sequence[str] = Field(default_factory=list) -class StringVariable(StringSegment, Variable): +class StringVariable(StringSegment, VariableBase): pass -class FloatVariable(FloatSegment, Variable): +class FloatVariable(FloatSegment, VariableBase): pass -class IntegerVariable(IntegerSegment, Variable): +class IntegerVariable(IntegerSegment, VariableBase): pass -class ObjectVariable(ObjectSegment, Variable): +class ObjectVariable(ObjectSegment, VariableBase): pass -class ArrayVariable(ArraySegment, Variable): +class ArrayVariable(ArraySegment, VariableBase): pass @@ -89,16 +89,16 @@ class SecretVariable(StringVariable): return encrypter.obfuscated_token(self.value) -class NoneVariable(NoneSegment, Variable): +class NoneVariable(NoneSegment, VariableBase): value_type: SegmentType = SegmentType.NONE value: None = None -class FileVariable(FileSegment, Variable): +class FileVariable(FileSegment, VariableBase): pass -class BooleanVariable(BooleanSegment, Variable): +class BooleanVariable(BooleanSegment, VariableBase): pass @@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel): value: Any -# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Variable` for type hinting when serialization is not required. +# The `Variable` type is used to enable serialization and deserialization with Pydantic. +# Use `VariableBase` for type hinting when serialization is not required. # # Note: -# - All variants in `VariableUnion` must inherit from the `Variable` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -VariableUnion: TypeAlias = Annotated[ +# - All variants in `Variable` must inherit from the `VariableBase` class. +# - The union must include all non-abstract subclasses of `VariableBase`. +Variable: TypeAlias = Annotated[ ( Annotated[NoneVariable, Tag(SegmentType.NONE)] | Annotated[StringVariable, Tag(SegmentType.STRING)] diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index fd78248c17..75f47691da 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.variables import Variable +from core.variables import VariableBase class ConversationVariableUpdater(Protocol): @@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol): """ @abc.abstractmethod - def update(self, conversation_id: str, variable: "Variable"): + def update(self, conversation_id: str, variable: "VariableBase"): """ Updates the value of the specified conversation variable in the underlying storage. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `Variable` instance containing the updated value. + :param variable: The `VariableBase` instance containing the updated value. """ pass diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 6dce03c94d..41276eb444 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.variables.variables import VariableUnion +from core.variables.variables import Variable class CommandType(StrEnum): @@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand): class VariableUpdate(BaseModel): """Represents a single variable update instruction.""" - value: VariableUnion = Field(description="New variable value") + value: Variable = Field(description="New variable value") class UpdateVariablesCommand(GraphEngineCommand): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e5d86414c1..91df2e4e0b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeIs from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( NodeExecutionType, @@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): datetime, list[GraphNodeEventBase], object | None, - dict[str, VariableUnion], + dict[str, Variable], LLMUsage, ] ], @@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): item: object, flask_app: Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: parent_pool = self.graph_runtime_state.variable_pool parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index ac2870aa65..9f5818f4bb 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, Variable): + if not isinstance(original_variable, VariableBase): raise VariableOperatorNodeError("assigned variable not found") match self.node_data.write_mode: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 486e6bb6a7..5857702e72 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # ==================== Validation Part # Check if variable exists - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=item.variable_selector) # Check if operation is supported @@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def _handle_item( self, *, - variable: Variable, + variable: VariableBase, operation: Operation, value: Any, ): diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 85ceb9d59e..d205c6ac8f 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager -from core.variables import Segment, SegmentGroup, Variable +from core.variables import Segment, SegmentGroup, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import FileSegment, ObjectSegment -from core.variables.variables import RAGPipelineVariableInput, VariableUnion +from core.variables.variables import RAGPipelineVariableInput, Variable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -32,7 +32,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -46,13 +46,13 @@ class VariablePool(BaseModel): description="System variables", default_factory=SystemVariable.empty, ) - environment_variables: Sequence[VariableUnion] = Field( + environment_variables: Sequence[Variable] = Field( description="Environment variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) - conversation_variables: Sequence[VariableUnion] = Field( + conversation_variables: Sequence[Variable] = Field( description="Conversation variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( description="RAG pipeline variables.", @@ -105,7 +105,7 @@ class VariablePool(BaseModel): f"got {len(selector)} elements" ) - if isinstance(value, Variable): + if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) @@ -114,9 +114,9 @@ class VariablePool(BaseModel): variable = variable_factory.segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + self.variable_dictionary[node_id][name] = cast(Variable, variable) @classmethod def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index ea0bdc3537..7992785fe1 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.variables import Variable +from core.variables import VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.runtime import VariablePool @@ -26,7 +26,7 @@ class VariableLoader(Protocol): """ @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: """Load variables based on the provided selectors. If the selectors are empty, this method should return an empty list. @@ -36,7 +36,7 @@ class VariableLoader(Protocol): :param: selectors: a list of string list, each inner list should have at least two elements: - the first element is the node ID, - the second element is the variable name. - :return: a list of Variable objects that match the provided selectors. + :return: a list of VariableBase objects that match the provided selectors. """ pass @@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader): Serves as a placeholder when no variable loading is needed. """ - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: return [] diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 494194369a..3f030ae127 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -38,7 +38,7 @@ from core.variables.variables import ( ObjectVariable, SecretVariable, StringVariable, - Variable, + VariableBase, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, @@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = { } -def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) -def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) -def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("variable"): raise VariableError("missing variable") return mapping["variable"] -def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: +def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase: """ This factory function is used to create the environment variable or the conversation variable, not support the File type. @@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if (value := mapping.get("value")) is None: raise VariableError("missing value") - result: Variable + result: VariableBase match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) + return cast(VariableBase, result) def build_segment(value: Any, /) -> Segment: @@ -285,8 +285,8 @@ def segment_to_variable( id: str | None = None, name: str | None = None, description: str = "", -) -> Variable: - if isinstance(segment, Variable): +) -> VariableBase: + if isinstance(segment, VariableBase): return segment name = name or selector[-1] id = id or str(uuid4()) @@ -297,7 +297,7 @@ def segment_to_variable( variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( - Variable, + VariableBase, variable_class( id=id, name=name, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index d037b0c442..2755f77f61 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.variables import SecretVariable, SegmentType, Variable +from core.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw): "value_type": value.value_type.value, "description": value.description, } - if isinstance(value, Variable): + if isinstance(value, VariableBase): return { "id": value.id, "name": value.name, diff --git a/api/models/workflow.py b/api/models/workflow.py index 072c6100b5..5d92da3fa1 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -46,7 +44,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Segment, SegmentType, Variable +from core.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -69,7 +67,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> WorkflowType: + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -82,7 +80,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. @@ -178,12 +176,12 @@ class Workflow(Base): # bug graph: str, features: str, created_by: str, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> Workflow: + ) -> "Workflow": workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -447,7 +445,7 @@ class Workflow(Base): # bug # decrypt secret variables value def decrypt_func( - var: Variable, + var: VariableBase, ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) @@ -463,7 +461,7 @@ class Workflow(Base): # bug return decrypted_results @environment_variables.setter - def environment_variables(self, value: Sequence[Variable]): + def environment_variables(self, value: Sequence[VariableBase]): if not value: self._environment_variables = "{}" return @@ -487,7 +485,7 @@ class Workflow(Base): # bug value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - def encrypt_func(var: Variable) -> Variable: + def encrypt_func(var: VariableBase) -> VariableBase: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) else: @@ -517,7 +515,7 @@ class Workflow(Base): # bug return result @property - def conversation_variables(self) -> Sequence[Variable]: + def conversation_variables(self) -> Sequence[VariableBase]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._conversation_variables is None: self._conversation_variables = "{}" @@ -527,7 +525,7 @@ class Workflow(Base): # bug return results @conversation_variables.setter - def conversation_variables(self, value: Sequence[Variable]): + def conversation_variables(self, value: Sequence[VariableBase]): self._conversation_variables = json.dumps( {var.name: var.model_dump() for var in value}, ensure_ascii=False, @@ -622,7 +620,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[WorkflowPause | None] = orm.relationship( + pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -692,7 +690,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: + def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( + offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base): back_populates="offload_data", ) - file: Mapped[UploadFile | None] = orm.relationship( + file: Mapped[Optional["UploadFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. @@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: + def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, @@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase): ) return obj - def to_variable(self) -> Variable: + def to_variable(self) -> VariableBase: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) @@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( + variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() variable.id = str(uuid4()) variable.created_at = naive_utc_now() @@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=node_id, @@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped[UploadFile] = orm.relationship( + upload_file: Mapped["UploadFile"] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped[WorkflowRun] = orm.relationship( + workflow_run: Mapped["WorkflowRun"] = orm.relationship( foreign_keys=[workflow_run_id], # require explicit preloading. lazy="raise", @@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: + def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index acc0ec2b22..92008d5ff1 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.variables.variables import Variable +from core.variables.variables import VariableBase from models import ConversationVariable @@ -13,7 +13,7 @@ class ConversationVariableUpdater: def __init__(self, session_maker: sessionmaker[Session]) -> None: self._session_maker: sessionmaker[Session] = session_maker - def update(self, conversation_id: str, variable: Variable) -> None: + def update(self, conversation_id: str, variable: VariableBase) -> None: stmt = select(ConversationVariable).where( ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1ba64813ba..2d8418900c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,7 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import VariableBase from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -270,8 +270,8 @@ class RagPipelineService: graph: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list, ) -> Workflow: """ diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 9407a2b3f0..70b0190231 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File -from core.variables import Segment, StringSegment, Variable +from core.variables import Segment, StringSegment, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ( ArrayFileSegment, @@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader): # Application ID for which variables are being loaded. _app_id: str _tenant_id: str - _fallback_variables: Sequence[Variable] + _fallback_variables: Sequence[VariableBase] def __init__( self, engine: Engine, app_id: str, tenant_id: str, - fallback_variables: Sequence[Variable] | None = None, + fallback_variables: Sequence[VariableBase] | None = None, ): self._engine = engine self._app_id = app_id @@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader): def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: return (selector[0], selector[1]) - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: if not selectors: return [] - # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. - variable_by_selector: dict[tuple[str, str], Variable] = {} + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance. + variable_by_selector: dict[tuple[str, str], VariableBase] = {} with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) @@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader): return list(variable_by_selector.values()) - def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: + def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]: # This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable` # and must remain synchronized with it. # Ideally, these should be co-located for better maintainability. diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b45a167b73..d8c3159178 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.variables import Variable -from core.variables.variables import VariableUnion +from core.variables import VariableBase +from core.variables.variables import Variable from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError @@ -198,8 +198,8 @@ class WorkflowService: features: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], ) -> Workflow: """ Sync draft workflow @@ -1044,7 +1044,7 @@ def _setup_variable_pool( workflow: Workflow, node_type: NodeType, conversation_id: str, - conversation_variables: list[Variable], + conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if node_type == NodeType.START or node_type.is_trigger_node: @@ -1070,9 +1070,9 @@ def _setup_variable_pool( system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - conversation_variables=cast(list[VariableUnion], conversation_variables), # + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + conversation_variables=cast(list[Variable], conversation_variables), # ) return variable_pool diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index af4f96ba23..aa16c8af1c 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -35,7 +35,6 @@ from core.variables.variables import ( SecretVariable, StringVariable, Variable, - VariableUnion, ) from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -96,7 +95,7 @@ class _Segments(BaseModel): class _Variables(BaseModel): - variables: list[VariableUnion] + variables: list[Variable] def create_test_file( @@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad: # Create one instance of each variable type test_file = create_test_file() - all_variables: list[VariableUnion] = [ + all_variables: list[Variable] = [ NoneVariable(name="none_var"), StringVariable(value="test string", name="string_var"), IntegerVariable(value=42, name="int_var"), diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 925142892c..fb4b18b57a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,7 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) -from core.variables.variables import Variable +from core.variables.variables import VariableBase def test_frozen_variables(): @@ -76,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var: Variable = StringVariable(name="text", value="text") + var: VariableBase = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9733bf60eb..b8869dbf1d 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -24,7 +24,7 @@ from core.variables.variables import ( IntegerVariable, ObjectVariable, StringVariable, - VariableUnion, + Variable, ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.runtime import VariablePool @@ -160,7 +160,7 @@ class TestVariablePoolSerialization: ) # Create environment variables with all types including ArrayFileVariable - env_vars: list[VariableUnion] = [ + env_vars: list[Variable] = [ StringVariable( id="env_string_id", name="env_string", @@ -182,7 +182,7 @@ class TestVariablePoolSerialization: ] # Create conversation variables with complex data - conv_vars: list[VariableUnion] = [ + conv_vars: list[Variable] = [ StringVariable( id="conv_string_id", name="conv_string", From 87f348a0dea279623a521aad99763d2452c41fbf Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 14 Jan 2026 09:46:41 +0800 Subject: [PATCH 22/29] feat: change param to pydantic model (#30870) --- .../console/datasets/datasets_document.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ac78d3854b..707d90f044 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -7,7 +7,7 @@ from typing import Literal, cast import sqlalchemy as sa from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel): name: str +class DocumentDatasetListParam(BaseModel): + page: int = Field(1, title="Page", description="Page number.") + limit: int = Field(20, title="Limit", description="Page size.") + search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.") + sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.") + status: str | None = Field(None, title="Status", description="Document status.") + fetch_val: str = Field("false", alias="fetch") + + register_schema_models( console_ns, KnowledgeConfig, @@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource): def get(self, dataset_id): current_user, current_tenant_id = current_account_with_tenant() dataset_id = str(dataset_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - sort = request.args.get("sort", default="-created_at", type=str) - status = request.args.get("status", default=None, type=str) + raw_args = request.args.to_dict() + param = DocumentDatasetListParam.model_validate(raw_args) + page = param.page + limit = param.limit + search = param.search + sort = param.sort_by + status = param.status # "yes", "true", "t", "y", "1" convert to True, while others convert to False. try: - fetch_val = request.args.get("fetch", default="false") + fetch_val = param.fetch_val if isinstance(fetch_val, bool): fetch = fetch_val else: From e389cd1665eb3d9155497e5d39079f0e4bacc571 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:56:02 +0800 Subject: [PATCH 23/29] chore(deps): bump filelock from 3.20.0 to 3.20.3 in /api (#30939) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/uv.lock b/api/uv.lock index 444c7f2f5a..3a3f86390b 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1965,11 +1965,11 @@ wheels = [ [[package]] name = "filelock" -version = "3.20.0" +version = "3.20.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, + { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" }, ] [[package]] From 7f9884e7a1d5e1892fe290e91549fa8835b594d6 Mon Sep 17 00:00:00 2001 From: UMDKyle Date: Tue, 13 Jan 2026 21:09:30 -0500 Subject: [PATCH 24/29] feat: Add option to delete or keep API keys when uninstalling plugin (#28201) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: -LAN- --- api/services/plugin/plugin_service.py | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index b8303eb724..411c335c17 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from mimetypes import guess_type from pydantic import BaseModel +from sqlalchemy import select from yarl import URL from configs import dify_config @@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller +from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.provider import ProviderCredential from models.provider_ids import GenericProviderID from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope @@ -506,6 +509,33 @@ class PluginService: @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: manager = PluginInstaller() + + # Get plugin info before uninstalling to delete associated credentials + try: + plugins = manager.list_plugins(tenant_id) + plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) + + if plugin: + plugin_id = plugin.plugin_id + logger.info("Deleting credentials for plugin: %s", plugin_id) + + # Delete provider credentials that match this plugin + credentials = db.session.scalars( + select(ProviderCredential).where( + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name.like(f"{plugin_id}/%"), + ) + ).all() + + for cred in credentials: + db.session.delete(cred) + + db.session.commit() + logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id) + except Exception as e: + logger.warning("Failed to delete credentials: %s", e) + # Continue with uninstall even if credential deletion fails + return manager.uninstall(tenant_id, plugin_installation_id) @staticmethod From e4b97fba29ea58eb9d982aa8a7fa44307873cc2e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:10:49 +0800 Subject: [PATCH 25/29] chore(deps): bump azure-core from 1.36.0 to 1.38.0 in /api (#30941) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/uv.lock b/api/uv.lock index 3a3f86390b..aacf408902 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -453,15 +453,15 @@ wheels = [ [[package]] name = "azure-core" -version = "1.36.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" }, + { url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" }, ] [[package]] From c327d0bb441c7d2a64ab8245c77f9a8113ddceb4 Mon Sep 17 00:00:00 2001 From: jialin li Date: Wed, 14 Jan 2026 10:11:30 +0800 Subject: [PATCH 26/29] fix: Correction to the full name of Volc TOS (#30741) --- api/configs/middleware/storage/volcengine_tos_storage_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index be01f2dc36..2a35300401 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings class VolcengineTOSStorageConfig(BaseSettings): """ - Configuration settings for Volcengine Tinder Object Storage (TOS) + Configuration settings for Volcengine Torch Object Storage (TOS) """ VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( From 138c56bd6e4a1dfdfcc039593ac0539eccbb8779 Mon Sep 17 00:00:00 2001 From: fanadong Date: Wed, 14 Jan 2026 10:21:26 +0800 Subject: [PATCH 27/29] fix(logstore): prevent SQL injection, fix serialization issues, and optimize initialization (#30697) --- api/extensions/ext_logstore.py | 56 ++- api/extensions/logstore/aliyun_logstore.py | 88 +++- api/extensions/logstore/aliyun_logstore_pg.py | 287 +++-------- .../logstore/repositories/__init__.py | 29 ++ ..._api_workflow_node_execution_repository.py | 60 ++- .../logstore_api_workflow_run_repository.py | 152 ++++-- .../logstore_workflow_execution_repository.py | 24 +- ...tore_workflow_node_execution_repository.py | 84 +++- api/extensions/logstore/sql_escape.py | 134 +++++ .../extensions/logstore/__init__.py | 1 + .../extensions/logstore/test_sql_escape.py | 469 ++++++++++++++++++ docker/.env.example | 8 + 12 files changed, 1033 insertions(+), 359 deletions(-) create mode 100644 api/extensions/logstore/sql_escape.py create mode 100644 api/tests/unit_tests/extensions/logstore/__init__.py create mode 100644 api/tests/unit_tests/extensions/logstore/test_sql_escape.py diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py index 502f0bb46b..cda2d1ad1e 100644 --- a/api/extensions/ext_logstore.py +++ b/api/extensions/ext_logstore.py @@ -10,6 +10,7 @@ import os from dotenv import load_dotenv +from configs import dify_config from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -19,12 +20,17 @@ def is_enabled() -> bool: """ Check if logstore extension is enabled. + Logstore is considered enabled when: + 1. All required Aliyun SLS environment variables are set + 2. At least one repository configuration points to a logstore implementation + Returns: - True if all required Aliyun SLS environment variables are set, False otherwise + True if logstore should be initialized, False otherwise """ # Load environment variables from .env file load_dotenv() + # Check if Aliyun SLS connection parameters are configured required_vars = [ "ALIYUN_SLS_ACCESS_KEY_ID", "ALIYUN_SLS_ACCESS_KEY_SECRET", @@ -33,24 +39,32 @@ def is_enabled() -> bool: "ALIYUN_SLS_PROJECT_NAME", ] - all_set = all(os.environ.get(var) for var in required_vars) + sls_vars_set = all(os.environ.get(var) for var in required_vars) - if not all_set: - logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set") + if not sls_vars_set: + return False - return all_set + # Check if any repository configuration points to logstore implementation + repository_configs = [ + dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY, + dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_RUN_REPOSITORY, + ] + + uses_logstore = any("logstore" in config.lower() for config in repository_configs) + + if not uses_logstore: + return False + + logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore") + return True def init_app(app: DifyApp): """ Initialize logstore on application startup. - - This function: - 1. Creates Aliyun SLS project if it doesn't exist - 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist - 3. Creates indexes with field configurations based on PostgreSQL table structures - - This operation is idempotent and only executes once during application startup. + If initialization fails, the application continues running without logstore features. Args: app: The Dify application instance @@ -58,17 +72,23 @@ def init_app(app: DifyApp): try: from extensions.logstore.aliyun_logstore import AliyunLogStore - logger.info("Initializing logstore...") + logger.info("Initializing Aliyun SLS Logstore...") - # Create logstore client and initialize project/logstores/indexes + # Create logstore client and initialize resources logstore_client = AliyunLogStore() logstore_client.init_project_logstore() - # Attach to app for potential later use app.extensions["logstore"] = logstore_client logger.info("Logstore initialized successfully") + except Exception: - logger.exception("Failed to initialize logstore") - # Don't raise - allow application to continue even if logstore init fails - # This ensures that the application can still run if logstore is misconfigured + logger.exception( + "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. " + "Application will continue but logstore features will NOT work.", + os.environ.get("ALIYUN_SLS_ENDPOINT"), + os.environ.get("ALIYUN_SLS_REGION"), + os.environ.get("ALIYUN_SLS_PROJECT_NAME"), + os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"), + ) + # Don't raise - allow application to continue even if logstore setup fails diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 8c64a25be4..f6a4765f14 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import os +import socket import threading import time from collections.abc import Sequence @@ -179,9 +180,18 @@ class AliyunLogStore: self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) - self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + self.log_enabled: bool = ( + os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true" + ) self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" + # Get timeout configuration + check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30)) + + # Pre-check endpoint connectivity to prevent indefinite hangs + self._check_endpoint_connectivity(self.endpoint, check_timeout) + # Initialize SDK client self.client = LogClient( self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region @@ -199,6 +209,49 @@ class AliyunLogStore: self.__class__._initialized = True + @staticmethod + def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None: + """ + Check if the SLS endpoint is reachable before creating LogClient. + Prevents indefinite hangs when the endpoint is unreachable. + + Args: + endpoint: SLS endpoint URL + timeout: Connection timeout in seconds + + Raises: + ConnectionError: If endpoint is not reachable + """ + # Parse endpoint URL to extract hostname and port + from urllib.parse import urlparse + + parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}") + hostname = parsed_url.hostname + port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) + + if not hostname: + raise ConnectionError(f"Invalid endpoint URL: {endpoint}") + + sock = None + try: + # Create socket and set timeout + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect((hostname, port)) + except Exception as e: + # Catch all exceptions and provide clear error message + error_type = type(e).__name__ + raise ConnectionError( + f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}" + ) from e + finally: + # Ensure socket is properly closed + if sock: + try: + sock.close() + except Exception: # noqa: S110 + pass # Ignore errors during cleanup + @property def supports_pg_protocol(self) -> bool: """Check if PG protocol is supported and enabled.""" @@ -220,19 +273,16 @@ class AliyunLogStore: try: self._use_pg_protocol = self._pg_client.init_connection() if self._use_pg_protocol: - logger.info("Successfully connected to project %s using PG protocol", self.project_name) + logger.info("Using PG protocol for project %s", self.project_name) # Check if scan_index is enabled for all logstores self._check_and_disable_pg_if_scan_index_disabled() return True else: - logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name) + logger.info("Using SDK mode for project %s", self.project_name) return False except Exception as e: - logger.warning( - "Failed to establish PG connection for project %s: %s. Will use SDK mode.", - self.project_name, - str(e), - ) + logger.info("Using SDK mode for project %s", self.project_name) + logger.debug("PG connection details: %s", str(e)) self._use_pg_protocol = False return False @@ -246,10 +296,6 @@ class AliyunLogStore: if self._use_pg_protocol: return - logger.info( - "Attempting delayed PG connection for newly created project %s ...", - self.project_name, - ) self._attempt_pg_connection_init() self.__class__._pg_connection_timer = None @@ -284,11 +330,7 @@ class AliyunLogStore: if project_is_new: # For newly created projects, schedule delayed PG connection self._use_pg_protocol = False - logger.info( - "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.", - self.project_name, - self.__class__._pg_connection_delay, - ) + logger.info("Using SDK mode for project %s (newly created)", self.project_name) if self.__class__._pg_connection_timer is not None: self.__class__._pg_connection_timer.cancel() self.__class__._pg_connection_timer = threading.Timer( @@ -299,7 +341,6 @@ class AliyunLogStore: self.__class__._pg_connection_timer.start() else: # For existing projects, attempt PG connection immediately - logger.info("Project %s already exists. Attempting PG connection...", self.project_name) self._attempt_pg_connection_init() def _check_and_disable_pg_if_scan_index_disabled(self) -> None: @@ -318,9 +359,9 @@ class AliyunLogStore: existing_config = self.get_existing_index_config(logstore_name) if existing_config and not existing_config.scan_index: logger.info( - "Logstore %s has scan_index=false, USE SDK mode for read/write operations. " - "PG protocol requires scan_index to be enabled.", + "Logstore %s requires scan_index enabled, using SDK mode for project %s", logstore_name, + self.project_name, ) self._use_pg_protocol = False # Close PG connection if it was initialized @@ -748,7 +789,6 @@ class AliyunLogStore: reverse=reverse, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " @@ -770,7 +810,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", @@ -845,7 +884,6 @@ class AliyunLogStore: query=full_query, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s", @@ -853,8 +891,7 @@ class AliyunLogStore: self.project_name, from_time, to_time, - query, - sql, + full_query, ) try: @@ -865,7 +902,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py index 35aa51ce53..874c20d144 100644 --- a/api/extensions/logstore/aliyun_logstore_pg.py +++ b/api/extensions/logstore/aliyun_logstore_pg.py @@ -7,8 +7,7 @@ from contextlib import contextmanager from typing import Any import psycopg2 -import psycopg2.pool -from psycopg2 import InterfaceError, OperationalError +from sqlalchemy import create_engine from configs import dify_config @@ -16,11 +15,7 @@ logger = logging.getLogger(__name__) class AliyunLogStorePG: - """ - PostgreSQL protocol support for Aliyun SLS LogStore. - - Handles PG connection pooling and operations for regions that support PG protocol. - """ + """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool.""" def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): """ @@ -36,24 +31,11 @@ class AliyunLogStorePG: self._access_key_secret = access_key_secret self._endpoint = endpoint self.project_name = project_name - self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None + self._engine: Any = None # SQLAlchemy Engine self._use_pg_protocol = False def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: - """ - Check if a TCP port is reachable using socket connection. - - This provides a fast check before attempting full database connection, - preventing long waits when connecting to unsupported regions. - - Args: - host: Hostname or IP address - port: Port number - timeout: Connection timeout in seconds (default: 2.0) - - Returns: - True if port is reachable, False otherwise - """ + """Fast TCP port check to avoid long waits on unsupported regions.""" try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) @@ -65,166 +47,101 @@ class AliyunLogStorePG: return False def init_connection(self) -> bool: - """ - Initialize PostgreSQL connection pool for SLS PG protocol support. - - Attempts to connect to SLS using PostgreSQL protocol. If successful, sets - _use_pg_protocol to True and creates a connection pool. If connection fails - (region doesn't support PG protocol or other errors), returns False. - - Returns: - True if PG protocol is supported and initialized, False otherwise - """ + """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support.""" try: - # Extract hostname from endpoint (remove protocol if present) pg_host = self._endpoint.replace("http://", "").replace("https://", "") - # Get pool configuration - pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) + # Pool configuration + pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5)) + max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5)) + pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600)) + pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true" - logger.debug( - "Check PG protocol connection to SLS: host=%s, project=%s", - pg_host, - self.project_name, - ) + logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name) - # Fast port connectivity check before attempting full connection - # This prevents long waits when connecting to unsupported regions + # Fast port check to avoid long waits if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): - logger.info( - "USE SDK mode for read/write operations, host=%s", - pg_host, - ) + logger.debug("Using SDK mode for host=%s", pg_host) return False - # Create connection pool - self._pg_pool = psycopg2.pool.SimpleConnectionPool( - minconn=1, - maxconn=pg_max_connections, - host=pg_host, - port=5432, - database=self.project_name, - user=self._access_key_id, - password=self._access_key_secret, - sslmode="require", - connect_timeout=5, - application_name=f"Dify-{dify_config.project.version}", + # Build connection URL + from urllib.parse import quote_plus + + username = quote_plus(self._access_key_id) + password = quote_plus(self._access_key_secret) + database_url = ( + f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require" ) - # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables - # Connection pool creation success already indicates connectivity + # Create SQLAlchemy engine with connection pool + self._engine = create_engine( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + pool_timeout=30, + connect_args={ + "connect_timeout": 5, + "application_name": f"Dify-{dify_config.project.version}-fixautocommit", + "keepalives": 1, + "keepalives_idle": 60, + "keepalives_interval": 10, + "keepalives_count": 5, + }, + ) self._use_pg_protocol = True logger.info( - "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.", + "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)", self.project_name, + pool_size, + pool_recycle, ) return True except Exception as e: - # PG connection failed - fallback to SDK mode self._use_pg_protocol = False - if self._pg_pool: + if self._engine: try: - self._pg_pool.closeall() + self._engine.dispose() except Exception: - logger.debug("Failed to close PG connection pool during cleanup, ignoring") - self._pg_pool = None + logger.debug("Failed to dispose engine during cleanup, ignoring") + self._engine = None - logger.info( - "PG protocol connection failed (region may not support PG protocol): %s. " - "Falling back to SDK mode for read/write operations.", - str(e), - ) - return False - - def _is_connection_valid(self, conn: Any) -> bool: - """ - Check if a connection is still valid. - - Args: - conn: psycopg2 connection object - - Returns: - True if connection is valid, False otherwise - """ - try: - # Check if connection is closed - if conn.closed: - return False - - # Quick ping test - execute a lightweight query - # For SLS PG protocol, we can't use SELECT 1 without FROM, - # so we just check the connection status - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - cursor.fetchone() - return True - except Exception: + logger.debug("Using SDK mode for region: %s", str(e)) return False @contextmanager def _get_connection(self): - """ - Context manager to get a PostgreSQL connection from the pool. + """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically.""" + if not self._engine: + raise RuntimeError("SQLAlchemy engine is not initialized") - Automatically validates and refreshes stale connections. - - Note: Aliyun SLS PG protocol does not support transactions, so we always - use autocommit mode. - - Yields: - psycopg2 connection object - - Raises: - RuntimeError: If PG pool is not initialized - """ - if not self._pg_pool: - raise RuntimeError("PG connection pool is not initialized") - - conn = self._pg_pool.getconn() + connection = self._engine.raw_connection() try: - # Validate connection and get a fresh one if needed - if not self._is_connection_valid(conn): - logger.debug("Connection is stale, marking as bad and getting a new one") - # Mark connection as bad and get a new one - self._pg_pool.putconn(conn, close=True) - conn = self._pg_pool.getconn() - - # Aliyun SLS PG protocol does not support transactions, always use autocommit - conn.autocommit = True - yield conn + connection.autocommit = True # SLS PG protocol does not support transactions + yield connection + except Exception: + raise finally: - # Return connection to pool (or close if it's bad) - if self._is_connection_valid(conn): - self._pg_pool.putconn(conn) - else: - self._pg_pool.putconn(conn, close=True) + connection.close() def close(self) -> None: - """Close the PostgreSQL connection pool.""" - if self._pg_pool: + """Dispose SQLAlchemy engine and close all connections.""" + if self._engine: try: - self._pg_pool.closeall() - logger.info("PG connection pool closed") + self._engine.dispose() + logger.info("SQLAlchemy engine disposed") except Exception: - logger.exception("Failed to close PG connection pool") + logger.exception("Failed to dispose engine") def _is_retriable_error(self, error: Exception) -> bool: - """ - Check if an error is retriable (connection-related issues). - - Args: - error: Exception to check - - Returns: - True if the error is retriable, False otherwise - """ - # Retry on connection-related errors - if isinstance(error, (OperationalError, InterfaceError)): + """Check if error is retriable (connection-related issues).""" + # Check for psycopg2 connection errors directly + if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)): return True - # Check error message for specific connection issues error_msg = str(error).lower() retriable_patterns = [ "connection", @@ -234,34 +151,18 @@ class AliyunLogStorePG: "reset by peer", "no route to host", "network", + "operational error", + "interface error", ] return any(pattern in error_msg for pattern in retriable_patterns) def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: - """ - Write log to SLS using PostgreSQL protocol with automatic retry. - - Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only - writes with log_version field for versioning, same as SDK implementation. - - Args: - logstore: Name of the logstore table - contents: List of (field_name, value) tuples - log_enabled: Whether to enable logging - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff).""" if not contents: return - # Extract field names and values from contents fields = [field_name for field_name, _ in contents] values = [value for _, value in contents] - - # Build INSERT statement with literal values - # Note: Aliyun SLS PG protocol doesn't support parameterized queries, - # so we need to use mogrify to safely create literal values field_list = ", ".join([f'"{field}"' for field in fields]) if log_enabled: @@ -272,67 +173,40 @@ class AliyunLogStorePG: len(contents), ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: - # Use mogrify to safely convert values to SQL literals placeholders = ", ".join(["%s"] * len(fields)) values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' cursor.execute(insert_sql) - # Success - exit retry loop return except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., data validation error), fail immediately - logger.exception( - "Failed to put logs to logstore %s via PG protocol (non-retriable error)", - logstore, - ) + logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed - logger.exception( - "Failed to put logs to logstore %s via PG protocol after %d attempts", - logstore, - max_retries, - ) + logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries) raise def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: - """ - Execute SQL query using PostgreSQL protocol with automatic retry. - - Args: - sql: SQL query string - logstore: Name of the logstore (for logging purposes) - log_enabled: Whether to enable logging - - Returns: - List of result rows as dictionaries - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Execute SQL query with automatic retry (3 attempts with exponential backoff).""" if log_enabled: logger.info( "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", @@ -341,20 +215,16 @@ class AliyunLogStorePG: sql, ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(sql) - - # Get column names from cursor description columns = [desc[0] for desc in cursor.description] - # Fetch all results and convert to list of dicts result = [] for row in cursor.fetchall(): row_dict = {} @@ -372,36 +242,31 @@ class AliyunLogStorePG: return result except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., SQL syntax error), fail immediately logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + "Failed to execute SQL on logstore %s (non-retriable error): sql=%s", logstore, sql, ) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + "Failed to execute SQL on logstore %s after %d attempts: sql=%s", logstore, max_retries, sql, ) raise - # This line should never be reached due to raise above, but makes type checker happy return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py index e69de29bb2..b5a4fcf844 100644 --- a/api/extensions/logstore/repositories/__init__.py +++ b/api/extensions/logstore/repositories/__init__.py @@ -0,0 +1,29 @@ +""" +LogStore repository utilities. +""" + +from typing import Any + + +def safe_float(value: Any, default: float = 0.0) -> float: + """ + Safely convert a value to float, handling 'null' strings and None. + """ + if value is None or value in {"null", ""}: + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + +def safe_int(value: Any, default: int = 0) -> int: + """ + Safely convert a value to int, handling 'null' strings and None. + """ + if value is None or value in {"null", ""}: + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 8c804d6bb5..f67723630b 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -14,6 +14,8 @@ from typing import Any from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.index = int(data.get("index", 0)) - model.elapsed_time = float(data.get("elapsed_time", 0)) + model.index = safe_int(data.get("index", 0)) + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) # Optional fields model.workflow_run_id = data.get("workflow_run_id") @@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep node_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_id = escape_identifier(workflow_id) + escaped_node_id = escape_identifier(node_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_id = '{workflow_id}' - AND node_id = '{node_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_id = '{escaped_workflow_id}' + AND node_id = '{escaped_node_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 @@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep else: # Use SDK with LogStore query syntax query = ( - f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}" ) from_time = 0 to_time = int(time.time()) # now @@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep workflow_run_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_run_id = escape_identifier(workflow_run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_run_id = '{workflow_run_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_run_id = '{escaped_workflow_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1000 @@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax - query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + query = ( + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_run_id: {escaped_workflow_run_id}" + ) from_time = 0 to_time = int(time.time()) # now @@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep """ logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) try: + # Escape parameters to prevent SQL injection + escaped_execution_id = escape_identifier(execution_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) - tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + if tenant_id: + escaped_tenant_id = escape_identifier(tenant_id) + tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'" + else: + tenant_filter = "" + sql_query = f""" SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1 """ @@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax + # Note: Values must be quoted in LogStore query syntax to prevent injection if tenant_id: - query = f"id: {execution_id} and tenant_id: {tenant_id}" + query = ( + f"id:{escape_logstore_query_value(execution_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)}" + ) else: - query = f"id: {execution_id}" + query = f"id:{escape_logstore_query_value(execution_id)}" from_time = 0 to_time = int(time.time()) # now diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 252cdcc4df..14382ed876 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -10,6 +10,7 @@ Key Features: - Optimized deduplication using finished_at IS NOT NULL filter - Window functions only when necessary (running status queries) - Multi-tenant data isolation and security +- SQL injection prevention via parameter escaping """ import logging @@ -22,6 +23,8 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowRun @@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.total_tokens = int(data.get("total_tokens", 0)) - model.total_steps = int(data.get("total_steps", 0)) - model.exceptions_count = int(data.get("exceptions_count", 0)) + model.total_tokens = safe_int(data.get("total_tokens", 0)) + model.total_steps = safe_int(data.get("total_steps", 0)) + model.exceptions_count = safe_int(data.get("exceptions_count", 0)) # Optional fields model.graph = data.get("graph") @@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: if model.finished_at and model.created_at: model.elapsed_time = (model.finished_at - model.created_at).total_seconds() else: - model.elapsed_time = float(data.get("elapsed_time", 0)) + # Use safe conversion to handle 'null' strings and None values + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) return model @@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): status, ) # Convert triggered_from to list if needed - if isinstance(triggered_from, WorkflowRunTriggeredFrom): + if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)): triggered_from_list = [triggered_from] else: triggered_from_list = list(triggered_from) - # Build triggered_from filter - triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) - # Build status filter - status_filter = f"AND status='{status}'" if status else "" + # Build triggered_from filter with escaped values + # Support both enum and string values for triggered_from + triggered_from_filter = " OR ".join( + [ + f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'" + for tf in triggered_from_list + ] + ) + + # Build status filter with escaped value + status_filter = f"AND status='{escape_sql_string(status)}'" if status else "" # Build last_id filter for pagination # Note: This is simplified. In production, you'd need to track created_at from last record @@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' AND ({triggered_from_filter}) {status_filter} {last_id_filter} @@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) try: + # Escape parameters to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' + AND tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + # Note: Values must be quoted in LogStore query syntax to prevent injection + query = ( + f"id:{escape_logstore_query_value(run_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)} " + f"and app_id:{escape_logstore_query_value(app_id)}" + ) from_time = 0 to_time = int(time.time()) # now @@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) try: + # Escape parameter to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id}" + # Note: Values must be quoted in LogStore query syntax + query = f"id:{escape_logstore_query_value(run_id)}" from_time = 0 to_time = int(time.time()) # now @@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): triggered_from, status, ) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + # Build time range filter time_filter = "" if time_range: @@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # If status is provided, simple count if status: + escaped_status = escape_sql_string(status) + if status == "running": # Running status requires window function sql = f""" @@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' - AND status='{status}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' + AND status='{escaped_status}' AND finished_at IS NOT NULL {time_filter} """ @@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # No status filter - get counts grouped by status # Use optimized query for finished runs, separate query for running try: + # Escape parameters (already escaped above, reuse variables) # Count finished runs grouped by status finished_sql = f""" SELECT status, COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY status @@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug( "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): created_by, COUNT(DISTINCT id) AS interactions FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date, created_by diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 1119534d52..9928879a7b 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.entities import WorkflowExecution from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( @@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) -def to_serializable(obj): - """ - Convert non-JSON-serializable objects into JSON-compatible formats. - - - Uses `to_dict()` if it's a callable method. - - Falls back to string representation. - """ - if hasattr(obj, "to_dict") and callable(obj.to_dict): - return obj.to_dict() - return str(obj) - - class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): def __init__( self, @@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" # Control flag for whether to write the `graph` field to LogStore. # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; @@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Generate log_version as nanosecond timestamp for record versioning log_version = str(time.time_ns()) + # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.) + json_converter = WorkflowRuntimeTypeConverter() + logstore_model = [ ("id", domain_model.id_), ("log_version", log_version), # Add log_version field for append-only writes @@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ("version", domain_model.workflow_version), ( "graph", - json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False) if domain_model.graph and self._enable_put_graph_field else "{}", ), ( "inputs", - json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) if domain_model.inputs else "{}", ), ( "outputs", - json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) if domain_model.outputs else "{}", ), diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 400a089516..4897171b12 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -24,6 +24,8 @@ from core.workflow.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier from libs.helper import extract_tenant_id from models import ( Account, @@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut node_execution_id=data.get("node_execution_id"), workflow_id=data.get("workflow_id", ""), workflow_execution_id=data.get("workflow_run_id"), - index=int(data.get("index", 0)), + index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), node_type=NodeType(data.get("node_type", "start")), @@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut outputs=outputs, status=status, error=data.get("error"), - elapsed_time=float(data.get("elapsed_time", 0.0)), + elapsed_time=safe_float(data.get("elapsed_time", 0.0)), metadata=domain_metadata, created_at=created_at, finished_at=finished_at, @@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: logger.debug( @@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): Save or update the inputs, process_data, or outputs associated with a specific node_execution record. - For LogStore implementation, this is similar to save() since we always write - complete records. We append a new record with updated data fields. + For LogStore implementation, this is a no-op for the LogStore write because save() + already writes all fields including inputs, process_data, and outputs. The caller + typically calls save() first to persist status/metadata, then calls save_execution_data() + to persist data fields. Since LogStore writes complete records atomically, we don't + need a separate write here to avoid duplicate records. + + However, if dual-write is enabled, we still need to call the SQL repository's + save_execution_data() method to properly update the SQL database. Args: execution: The NodeExecution instance with data to save """ - logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) - # In LogStore, we simply write a new complete record with the data - # The log_version timestamp will ensure this is treated as the latest version - self.save(execution) + logger.debug( + "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs + # Calling save() again would create a duplicate record in the append-only LogStore + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save_execution_data(execution) + logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup def get_by_workflow_run( self, @@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. - Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. - This ensures we only get the final version of each node execution. + Uses LogStore SQL query with window function to get the latest version of each node execution. + This ensures we only get the most recent version of each node execution record. Args: workflow_run_id: The workflow run ID order_config: Optional configuration for ordering results @@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): A list of NodeExecution instances Note: - This method filters by finished_at IS NOT NULL to avoid duplicates from - version updates. For complete history including intermediate states, - a different query strategy would be needed. + This method uses ROW_NUMBER() window function partitioned by node_execution_id + to get the latest version (highest log_version) of each node execution. """ logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) - # Build SQL query with deduplication using finished_at IS NOT NULL - # This optimization avoids window functions for common case where we only - # want the final state of each node execution + # Build SQL query with deduplication using window function + # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) + # ensures we get the latest version of each node execution - # Build ORDER BY clause + # Escape parameters to prevent SQL injection + escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_tenant_id = escape_identifier(self._tenant_id) + + # Build ORDER BY clause for outer query order_clause = "" if order_config and order_config.order_by: order_fields = [] @@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): if order_fields: order_clause = "ORDER BY " + ", ".join(order_fields) - sql = f""" - SELECT * - FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{workflow_run_id}' - AND tenant_id='{self._tenant_id}' - AND finished_at IS NOT NULL - """ - + # Build app_id filter for subquery + app_id_filter = "" if self._app_id: - sql += f" AND app_id='{self._app_id}'" + escaped_app_id = escape_identifier(self._app_id) + app_id_filter = f" AND app_id='{escaped_app_id}'" + + # Use window function to get latest version of each node execution + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{escaped_workflow_run_id}' + AND tenant_id='{escaped_tenant_id}' + {app_id_filter} + ) t + WHERE rn = 1 + """ if order_clause: sql += f" {order_clause}" diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py new file mode 100644 index 0000000000..d88d6bd959 --- /dev/null +++ b/api/extensions/logstore/sql_escape.py @@ -0,0 +1,134 @@ +""" +SQL Escape Utility for LogStore Queries + +This module provides escaping utilities to prevent injection attacks in LogStore queries. + +LogStore supports two query modes: +1. PG Protocol Mode: Uses SQL syntax with single quotes for strings +2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes + +Key Security Concerns: +- Prevent tenant A from accessing tenant B's data via injection +- SLS queries are read-only, so we focus on data access control +- Different escaping strategies for SQL vs LogStore query syntax +""" + + +def escape_sql_string(value: str) -> str: + """ + Escape a string value for safe use in SQL queries. + + This function escapes single quotes by doubling them, which is the standard + SQL escaping method. This prevents SQL injection by ensuring that user input + cannot break out of string literals. + + Args: + value: The string value to escape + + Returns: + Escaped string safe for use in SQL queries + + Examples: + >>> escape_sql_string("normal_value") + "normal_value" + >>> escape_sql_string("value' OR '1'='1") + "value'' OR ''1''=''1" + >>> escape_sql_string("tenant's_id") + "tenant''s_id" + + Security: + - Prevents breaking out of string literals + - Stops injection attacks like: ' OR '1'='1 + - Protects against cross-tenant data access + """ + if not value: + return value + + # Escape single quotes by doubling them (standard SQL escaping) + # This prevents breaking out of string literals in SQL queries + return value.replace("'", "''") + + +def escape_identifier(value: str) -> str: + """ + Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use. + + This function is for PG protocol mode (SQL syntax). + For SDK mode, use escape_logstore_query_value() instead. + + Args: + value: The identifier value to escape + + Returns: + Escaped identifier safe for use in SQL queries + + Examples: + >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000") + "550e8400-e29b-41d4-a716-446655440000" + >>> escape_identifier("tenant_id' OR '1'='1") + "tenant_id'' OR ''1''=''1" + + Security: + - Prevents SQL injection via identifiers + - Stops cross-tenant access attempts + - Works for UUIDs, alphanumeric IDs, and similar identifiers + """ + # For identifiers, use the same escaping as strings + # This is simple and effective for preventing injection + return escape_sql_string(value) + + +def escape_logstore_query_value(value: str) -> str: + """ + Escape value for LogStore query syntax (SDK mode). + + LogStore query syntax rules: + 1. Keywords (and/or/not) are case-insensitive + 2. Single quotes are ordinary characters (no special meaning) + 3. Double quotes wrap values: key:"value" + 4. Backslash is the escape character: + - \" for double quote inside value + - \\ for backslash itself + 5. Parentheses can change query structure + + To prevent injection: + - Wrap value in double quotes to treat special chars as literals + - Escape backslashes and double quotes using backslash + + Args: + value: The value to escape for LogStore query syntax + + Returns: + Quoted and escaped value safe for LogStore query syntax (includes the quotes) + + Examples: + >>> escape_logstore_query_value("normal_value") + '"normal_value"' + >>> escape_logstore_query_value("value or field:evil") + '"value or field:evil"' # 'or' and ':' are now literals + >>> escape_logstore_query_value('value"test') + '"value\\"test"' # Internal double quote escaped + >>> escape_logstore_query_value('value\\test') + '"value\\\\test"' # Backslash escaped + + Security: + - Prevents injection via and/or/not keywords + - Prevents injection via colons (:) + - Prevents injection via parentheses + - Protects against cross-tenant data access + + Note: + Escape order is critical: backslash first, then double quotes. + Otherwise, we'd double-escape the escape character itself. + """ + if not value: + return '""' + + # IMPORTANT: Escape backslashes FIRST, then double quotes + # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly) + escaped = value.replace("\\", "\\\\") # \ -> \\ + escaped = escaped.replace('"', '\\"') # " -> \" + + # Wrap in double quotes to treat as literal string + # This prevents and/or/not/:/() from being interpreted as operators + return f'"{escaped}"' diff --git a/api/tests/unit_tests/extensions/logstore/__init__.py b/api/tests/unit_tests/extensions/logstore/__init__.py new file mode 100644 index 0000000000..fe9ada9128 --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/__init__.py @@ -0,0 +1 @@ +"""LogStore extension unit tests.""" diff --git a/api/tests/unit_tests/extensions/logstore/test_sql_escape.py b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py new file mode 100644 index 0000000000..63172b3f9b --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py @@ -0,0 +1,469 @@ +""" +Unit tests for SQL escape utility functions. + +These tests ensure that SQL injection attacks are properly prevented +in LogStore queries, particularly for cross-tenant access scenarios. +""" + +import pytest + +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string + + +class TestEscapeSQLString: + """Test escape_sql_string function.""" + + def test_escape_empty_string(self): + """Test escaping empty string.""" + assert escape_sql_string("") == "" + + def test_escape_normal_string(self): + """Test escaping string without special characters.""" + assert escape_sql_string("tenant_abc123") == "tenant_abc123" + assert escape_sql_string("app-uuid-1234") == "app-uuid-1234" + + def test_escape_single_quote(self): + """Test escaping single quote.""" + # Single quote should be doubled + assert escape_sql_string("tenant'id") == "tenant''id" + assert escape_sql_string("O'Reilly") == "O''Reilly" + + def test_escape_multiple_quotes(self): + """Test escaping multiple single quotes.""" + assert escape_sql_string("a'b'c") == "a''b''c" + assert escape_sql_string("'''") == "''''''" + + # === SQL Injection Attack Scenarios === + + def test_prevent_boolean_injection(self): + """Test prevention of boolean injection attacks.""" + # Classic OR 1=1 attack + malicious_input = "tenant' OR '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR ''1''=''1" + + # When used in SQL, this becomes a safe string literal + sql = f"WHERE tenant_id='{escaped}'" + assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'" + # The entire input is now a string literal that won't match any tenant + + def test_prevent_or_injection(self): + """Test prevention of OR-based injection.""" + malicious_input = "tenant_a' OR tenant_id='tenant_b" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant_a'' OR tenant_id=''tenant_b" + + sql = f"WHERE tenant_id='{escaped}'" + # The OR is now part of the string literal, not SQL logic + assert "OR tenant_id=" in sql + # The SQL has: opening ', doubled internal quotes '', and closing ' + assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'" + + def test_prevent_union_injection(self): + """Test prevention of UNION-based injection.""" + malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1" + + # UNION becomes part of the string literal + assert "UNION" in escaped + assert escaped.count("''") == 4 # All internal quotes are doubled + + def test_prevent_comment_injection(self): + """Test prevention of comment-based injection.""" + # SQL comment to bypass remaining conditions + malicious_input = "tenant' --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' --" + + sql = f"WHERE tenant_id='{escaped}' AND deleted=false" + # The -- is now inside the string, not a SQL comment + assert "--" in sql + assert "AND deleted=false" in sql # This part is NOT commented out + + def test_prevent_semicolon_injection(self): + """Test prevention of semicolon-based multi-statement injection.""" + malicious_input = "tenant'; DROP TABLE users; --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant''; DROP TABLE users; --" + + # Semicolons and DROP are now part of the string + assert "DROP TABLE" in escaped + + def test_prevent_time_based_blind_injection(self): + """Test prevention of time-based blind SQL injection.""" + malicious_input = "tenant' AND SLEEP(5) --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' AND SLEEP(5) --" + + # SLEEP becomes part of the string + assert "SLEEP" in escaped + + def test_prevent_wildcard_injection(self): + """Test prevention of wildcard-based injection.""" + malicious_input = "tenant' OR tenant_id LIKE '%" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR tenant_id LIKE ''%" + + # The LIKE and wildcard are now part of the string + assert "LIKE" in escaped + + def test_prevent_null_byte_injection(self): + """Test handling of null bytes.""" + # Null bytes can sometimes bypass filters + malicious_input = "tenant\x00' OR '1'='1" + escaped = escape_sql_string(malicious_input) + # Null byte is preserved, but quote is escaped + assert "''1''=''1" in escaped + + # === Real-world SAAS Scenarios === + + def test_cross_tenant_access_attempt(self): + """Test prevention of cross-tenant data access.""" + # Attacker tries to access another tenant's data + attacker_input = "tenant_b' OR tenant_id='tenant_a" + escaped = escape_sql_string(attacker_input) + + sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'" + # The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a" + # which doesn't exist - preventing access to either tenant's data + assert "tenant_b'' OR tenant_id=''tenant_a" in sql + + def test_cross_app_access_attempt(self): + """Test prevention of cross-application data access.""" + attacker_input = "app1' OR app_id='app2" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE app_id='{escaped}'" + # Cannot access app2's data + assert "app1'' OR app_id=''app2" in sql + + def test_bypass_status_filter(self): + """Test prevention of bypassing status filters.""" + # Try to see all statuses instead of just 'running' + attacker_input = "running' OR status LIKE '%" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE status='{escaped}'" + # Status condition is not bypassed + assert "running'' OR status LIKE ''%" in sql + + # === Edge Cases === + + def test_escape_only_quotes(self): + """Test string with only quotes.""" + assert escape_sql_string("'") == "''" + assert escape_sql_string("''") == "''''" + + def test_escape_mixed_content(self): + """Test string with mixed quotes and other chars.""" + input_str = "It's a 'test' of O'Reilly's code" + escaped = escape_sql_string(input_str) + assert escaped == "It''s a ''test'' of O''Reilly''s code" + + def test_escape_unicode_with_quotes(self): + """Test Unicode strings with quotes.""" + input_str = "租户' OR '1'='1" + escaped = escape_sql_string(input_str) + assert escaped == "租户'' OR ''1''=''1" + + +class TestEscapeIdentifier: + """Test escape_identifier function.""" + + def test_escape_uuid(self): + """Test escaping UUID identifiers.""" + uuid = "550e8400-e29b-41d4-a716-446655440000" + assert escape_identifier(uuid) == uuid + + def test_escape_alphanumeric_id(self): + """Test escaping alphanumeric identifiers.""" + assert escape_identifier("tenant_123") == "tenant_123" + assert escape_identifier("app-abc-123") == "app-abc-123" + + def test_escape_identifier_with_quote(self): + """Test escaping identifier with single quote.""" + malicious = "tenant' OR '1'='1" + escaped = escape_identifier(malicious) + assert escaped == "tenant'' OR ''1''=''1" + + def test_identifier_injection_attempt(self): + """Test prevention of injection through identifiers.""" + # Common identifier injection patterns + test_cases = [ + ("id' OR '1'='1", "id'' OR ''1''=''1"), + ("id'; DROP TABLE", "id''; DROP TABLE"), + ("id' UNION SELECT", "id'' UNION SELECT"), + ] + + for malicious, expected in test_cases: + assert escape_identifier(malicious) == expected + + +class TestSQLInjectionIntegration: + """Integration tests simulating real SQL construction scenarios.""" + + def test_complete_where_clause_safety(self): + """Test that a complete WHERE clause is safe from injection.""" + # Simulating typical query construction + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + run_id = "run' --" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + escaped_run = escape_identifier(run_id) + + sql = f""" + SELECT * FROM workflow_runs + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND id='{escaped_run}' + """ + + # Verify all special characters are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + assert "run'' --" in sql + + # Verify SQL structure is preserved (3 conditions with AND) + assert sql.count("AND") == 2 + + def test_multiple_conditions_with_injection_attempts(self): + """Test multiple conditions all attempting injection.""" + conditions = { + "tenant_id": "t1' OR tenant_id='t2", + "app_id": "a1' OR app_id='a2", + "status": "running' OR '1'='1", + } + + where_parts = [] + for field, value in conditions.items(): + escaped = escape_sql_string(value) + where_parts.append(f"{field}='{escaped}'") + + where_clause = " AND ".join(where_parts) + + # All injection attempts are neutralized + assert "t1'' OR tenant_id=''t2" in where_clause + assert "a1'' OR app_id=''a2" in where_clause + assert "running'' OR ''1''=''1" in where_clause + + # AND structure is preserved + assert where_clause.count(" AND ") == 2 + + @pytest.mark.parametrize( + ("attack_vector", "description"), + [ + ("' OR '1'='1", "Boolean injection"), + ("' OR '1'='1' --", "Boolean with comment"), + ("' UNION SELECT * FROM users --", "Union injection"), + ("'; DROP TABLE workflow_runs; --", "Destructive command"), + ("' AND SLEEP(10) --", "Time-based blind"), + ("' OR tenant_id LIKE '%", "Wildcard injection"), + ("admin' --", "Comment bypass"), + ("' OR 1=1 LIMIT 1 --", "Limit bypass"), + ], + ) + def test_common_injection_vectors(self, attack_vector, description): + """Test protection against common injection attack vectors.""" + escaped = escape_sql_string(attack_vector) + + # Build SQL + sql = f"WHERE tenant_id='{escaped}'" + + # Verify the attack string is now a safe literal + # The key indicator: all internal single quotes are doubled + internal_quotes = escaped.count("''") + original_quotes = attack_vector.count("'") + + # Each original quote should be doubled + assert internal_quotes == original_quotes + + # Verify SQL has exactly 2 quotes (opening and closing) + assert sql.count("'") >= 2 # At least opening and closing + + def test_logstore_specific_scenario(self): + """Test SQL injection prevention in LogStore-specific scenarios.""" + # Simulate LogStore query with window function + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM workflow_execution_logstore + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + """ + + # Complex query structure is maintained + assert "ROW_NUMBER()" in sql + assert "PARTITION BY id" in sql + + # Injection attempts are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + + +# ==================================================================================== +# Tests for LogStore Query Syntax (SDK Mode) +# ==================================================================================== + + +class TestLogStoreQueryEscape: + """Test escape_logstore_query_value for SDK mode query syntax.""" + + def test_normal_value(self): + """Test escaping normal alphanumeric value.""" + value = "550e8400-e29b-41d4-a716-446655440000" + escaped = escape_logstore_query_value(value) + + # Should be wrapped in double quotes + assert escaped == '"550e8400-e29b-41d4-a716-446655440000"' + + def test_empty_value(self): + """Test escaping empty string.""" + assert escape_logstore_query_value("") == '""' + + def test_value_with_and_keyword(self): + """Test that 'and' keyword is neutralized when quoted.""" + malicious = "value and field:evil" + escaped = escape_logstore_query_value(malicious) + + # Should be wrapped in quotes, making 'and' a literal + assert escaped == '"value and field:evil"' + + # Simulate using in query + query = f"tenant_id:{escaped}" + assert query == 'tenant_id:"value and field:evil"' + + def test_value_with_or_keyword(self): + """Test that 'or' keyword is neutralized when quoted.""" + malicious = "tenant_a or tenant_id:tenant_b" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"tenant_a or tenant_id:tenant_b"' + + query = f"tenant_id:{escaped}" + assert "or" in query # Present but as literal string + + def test_value_with_not_keyword(self): + """Test that 'not' keyword is neutralized when quoted.""" + malicious = "not field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"not field:value"' + + def test_value_with_parentheses(self): + """Test that parentheses are neutralized when quoted.""" + malicious = "(tenant_a or tenant_b)" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"(tenant_a or tenant_b)"' + assert "(" in escaped # Present as literal + assert ")" in escaped # Present as literal + + def test_value_with_colon(self): + """Test that colons are neutralized when quoted.""" + malicious = "field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"field:value"' + assert ":" in escaped # Present as literal + + def test_value_with_double_quotes(self): + """Test that internal double quotes are escaped.""" + value_with_quotes = 'tenant"test"value' + escaped = escape_logstore_query_value(value_with_quotes) + + # Double quotes should be escaped with backslash + assert escaped == '"tenant\\"test\\"value"' + # Should have outer quotes plus escaped inner quotes + assert '\\"' in escaped + + def test_value_with_backslash(self): + """Test that backslashes are escaped.""" + value_with_backslash = "tenant\\test" + escaped = escape_logstore_query_value(value_with_backslash) + + # Backslash should be escaped + assert escaped == '"tenant\\\\test"' + assert "\\\\" in escaped + + def test_value_with_backslash_and_quote(self): + """Test escaping both backslash and double quote.""" + value = 'path\\to\\"file"' + escaped = escape_logstore_query_value(value) + + # Both should be escaped + assert escaped == '"path\\\\to\\\\\\"file\\""' + # Verify escape order is correct + assert "\\\\" in escaped # Escaped backslash + assert '\\"' in escaped # Escaped double quote + + def test_complex_injection_attempt(self): + """Test complex injection combining multiple operators.""" + malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")' + escaped = escape_logstore_query_value(malicious) + + # All special chars should be literals or escaped + assert escaped.startswith('"') + assert escaped.endswith('"') + # Inner double quotes escaped, operators become literals + assert "or" in escaped + assert "and" in escaped + assert '\\"' in escaped # Escaped quotes + + def test_only_backslash(self): + """Test escaping a single backslash.""" + assert escape_logstore_query_value("\\") == '"\\\\"' + + def test_only_double_quote(self): + """Test escaping a single double quote.""" + assert escape_logstore_query_value('"') == '"\\""' + + def test_multiple_backslashes(self): + """Test escaping multiple consecutive backslashes.""" + assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6 + + def test_escape_sequence_like_input(self): + """Test that existing escape sequences are properly escaped.""" + # Input looks like already escaped, but we still escape it + value = 'value\\"test' + escaped = escape_logstore_query_value(value) + # \\ -> \\\\, " -> \" + assert escaped == '"value\\\\\\"test"' + + +@pytest.mark.parametrize( + ("attack_scenario", "field", "malicious_value"), + [ + ("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"), + ("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"), + ("Boolean logic", "status", "succeeded or status:failed"), + ("Negation", "tenant_id", "not tenant_a"), + ("Field injection", "run_id", "run123 and tenant_id:evil_tenant"), + ("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"), + ("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'), + ("Backslash escape bypass", "app_id", "app\\ and app_id:evil"), + ], +) +def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str): + """Test that various LogStore query injection attempts are neutralized.""" + escaped = escape_logstore_query_value(malicious_value) + + # Build query + query = f"{field}:{escaped}" + + # All operators should be within quoted string (literals) + assert escaped.startswith('"') + assert escaped.endswith('"') + + # Verify the full query structure is safe + assert query.count(":") >= 1 # At least the main field:value separator diff --git a/docker/.env.example b/docker/.env.example index e7cb8711ce..9a3a7239c6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Options: # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation # Options: # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository # API workflow run repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository # API workflow node execution repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository # Workflow log cleanup configuration From 3473ff7ad1b645bddb24dafefd64e40d62cc0efe Mon Sep 17 00:00:00 2001 From: heyszt <270985384@qq.com> Date: Wed, 14 Jan 2026 10:21:46 +0800 Subject: [PATCH 28/29] fix: use Factory to create repository in Aliyun Trace (#30899) --- api/core/ops/aliyun_trace/aliyun_trace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index cf6659150f..22ad756c91 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db @@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) session_factory = sessionmaker(bind=db.engine) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, From d095bd413b1cb7075d7609e4c2b47d1f34619596 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 14 Jan 2026 10:22:31 +0800 Subject: [PATCH 29/29] fix: fix LOOP_CHILDREN_Z_INDEX (#30719) --- .../workflow/hooks/use-nodes-interactions.ts | 39 ++++++++++++------- .../workflow/nodes/loop/use-interactions.ts | 11 +++--- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index afb47d5994..8277e7dac8 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1602,6 +1602,7 @@ export const useNodesInteractions = () => { const offsetX = currentPosition.x - x const offsetY = currentPosition.y - y let idMapping: Record = {} + const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = [] clipboardElements.forEach((nodeToPaste, index) => { const nodeType = nodeToPaste.data.type @@ -1615,6 +1616,7 @@ export const useNodesInteractions = () => { _isBundled: false, _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], + _dimmed: false, title: genNewNodeTitleFromOld(nodeToPaste.data.title), }, position: { @@ -1682,27 +1684,24 @@ export const useNodesInteractions = () => { return // handle paste to nested block - if (selectedNode.data.type === BlockEnum.Iteration) { - newNode.data.isInIteration = true - newNode.data.iteration_id = selectedNode.data.iteration_id - newNode.parentId = selectedNode.id - newNode.positionAbsolute = { - x: newNode.position.x, - y: newNode.position.y, - } - // set position base on parent node - newNode.position = getNestedNodePosition(newNode, selectedNode) - } - else if (selectedNode.data.type === BlockEnum.Loop) { - newNode.data.isInLoop = true - newNode.data.loop_id = selectedNode.data.loop_id + if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) { + const isIteration = selectedNode.data.type === BlockEnum.Iteration + + newNode.data.isInIteration = isIteration + newNode.data.iteration_id = isIteration ? selectedNode.id : undefined + newNode.data.isInLoop = !isIteration + newNode.data.loop_id = !isIteration ? selectedNode.id : undefined + newNode.parentId = selectedNode.id + newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX newNode.positionAbsolute = { x: newNode.position.x, y: newNode.position.y, } // set position base on parent node newNode.position = getNestedNodePosition(newNode, selectedNode) + // update parent children array like native add + parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type }) } } } @@ -1733,7 +1732,17 @@ export const useNodesInteractions = () => { } }) - setNodes([...nodes, ...nodesToPaste]) + const newNodes = produce(nodes, (draft: Node[]) => { + parentChildrenToAppend.forEach(({ parentId, childId, childType }) => { + const p = draft.find(n => n.id === parentId) + if (p) { + p.data._children?.push({ nodeId: childId, nodeType: childType }) + } + }) + draft.push(...nodesToPaste) + }) + + setNodes(newNodes) setEdges([...edges, ...edgesToPaste]) saveStateToHistory(WorkflowHistoryEvent.NodePaste, { nodeId: nodesToPaste?.[0]?.id, diff --git a/web/app/components/workflow/nodes/loop/use-interactions.ts b/web/app/components/workflow/nodes/loop/use-interactions.ts index 006d8f963b..5e8f6ae36c 100644 --- a/web/app/components/workflow/nodes/loop/use-interactions.ts +++ b/web/app/components/workflow/nodes/loop/use-interactions.ts @@ -7,6 +7,7 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' import { useNodesMetaData } from '@/app/components/workflow/hooks' import { + LOOP_CHILDREN_Z_INDEX, LOOP_PADDING, } from '../../constants' import { @@ -114,9 +115,7 @@ export const useNodeLoopInteractions = () => { return childrenNodes.map((child, index) => { const childNodeType = child.data.type as BlockEnum - const { - defaultValue, - } = nodesMetaDataMap![childNodeType] + const { defaultValue } = nodesMetaDataMap![childNodeType] const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType) const { newNode } = generateNewNode({ type: getNodeCustomTypeByNodeDataType(childNodeType), @@ -127,15 +126,17 @@ export const useNodeLoopInteractions = () => { _isBundled: false, _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], + _dimmed: false, title: nodesWithSameType.length > 0 ? `${defaultValue.title} ${nodesWithSameType.length + 1}` : defaultValue.title, + isInLoop: true, loop_id: newNodeId, - + type: childNodeType, }, position: child.position, positionAbsolute: child.positionAbsolute, parentId: newNodeId, extent: child.extent, - zIndex: child.zIndex, + zIndex: LOOP_CHILDREN_Z_INDEX, }) newNode.id = `${newNodeId}${newNode.id + index}` return newNode