diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 0ed18d71d1..98a94592ab 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -367,7 +367,7 @@ For each extraction: ┌────────────────────────────────────────┐ │ 1. Extract code │ │ 2. Run: pnpm lint:fix │ - │ 3. Run: pnpm type-check:tsgo │ + │ 3. Run: pnpm type-check │ │ 4. Run: pnpm test │ │ 5. Test functionality manually │ │ 6. PASS? → Next extraction │ diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 4da070bdbf..105c979c58 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path: - ✅ **Import real project components** directly (including base components and siblings) - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers -- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`) - ❌ **DO NOT mock** sibling/child components in the same directory > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. @@ -325,12 +325,12 @@ For more detailed information, refer to: ### Reference Examples in Codebase - `web/utils/classnames.spec.ts` - Utility function tests -- `web/app/components/base/button/index.spec.tsx` - Component tests +- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests - `web/__mocks__/provider-context.ts` - Mock factory example ### Project Configuration -- `web/vitest.config.ts` - Vitest configuration +- `web/vite.config.ts` - Vite/Vitest configuration - `web/vitest.setup.ts` - Test environment setup - `web/scripts/analyze-component.js` - Component analysis tool - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 10b8fb66f9..519c3f166f 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Integration vs Mocking -- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.) 
- [ ] Import real project components instead of mocking - [ ] Only mock: API calls, complex context providers, third-party libs with side effects - [ ] Prefer integration testing when using single spec file @@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks -- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations @@ -127,7 +127,7 @@ For the current file being tested: - [ ] Run full directory test: `pnpm test path/to/directory/` - [ ] Check coverage report: `pnpm test:coverage` - [ ] Run `pnpm lint:fix` on all test files -- [ ] Run `pnpm type-check:tsgo` +- [ ] Run `pnpm type-check` ## Common Issues to Watch diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5..8c2f1c0c58 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -2,29 +2,27 @@ ## ⚠️ Important: What NOT to Mock -### DO NOT Mock Base Components +### DO NOT Mock Base Components or dify-ui Primitives -**Never mock components from `@/app/components/base/`** such as: +**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as: -- `Loading`, `Spinner` -- `Button`, `Input`, `Select` -- `Tooltip`, `Modal`, `Dropdown` -- `Icon`, `Badge`, `Tag` +- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag` +- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast` **Why?** -- Base components will have their own dedicated tests +- These components have their own dedicated tests - Mocking them creates false positives (tests pass but real integration fails) - Using real components tests actual integration behavior ```typescript -// ❌ WRONG: Don't mock base components +// ❌ WRONG: Don't mock base components or dify-ui primitives vi.mock('@/app/components/base/loading', () => () =>
<div>Loading</div>
) -vi.mock('@/app/components/base/button', () => ({ children }: any) => ) +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => })) -// ✅ CORRECT: Import and use real base components +// ✅ CORRECT: Import and use the real components import Loading from '@/app/components/base/loading' -import Button from '@/app/components/base/button' +import { Button } from '@langgenius/dify-ui/button' // They will render normally in tests ``` @@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ✅ DO -1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly 1. **Use real project components** - Prefer importing over mocking 1. **Use real Zustand stores** - Set test state via `store.setState()` 1. **Reset mocks in `beforeEach`**, not `afterEach` @@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ❌ DON'T -1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.) 1. **Don't mock Zustand store modules** - Use real stores with `setState()` 1. Don't mock components you can import directly 1. Don't create overly simplified mocks that miss conditional logic @@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ``` Need to use a component in test? │ -├─ Is it from @/app/components/base/*? +├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*? │ └─ YES → Import real component, DO NOT mock │ ├─ Is it a project component? diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index b92d4c35a8..7460636824 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,7 +7,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml deleted file mode 100644 index b0f0a36bc9..0000000000 --- a/.github/workflows/anti-slop.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Anti-Slop PR Check - -on: - pull_request_target: - types: [opened, edited, 
synchronize] - -permissions: - pull-requests: write - contents: read - -jobs: - anti-slop: - runs-on: ubuntu-latest - steps: - - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - close-pr: false - failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index fd910531db..bd47abc710 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -16,7 +16,7 @@ concurrency: jobs: api-unit: name: API Unit Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 env: COVERAGE_FILE: coverage-unit defaults: @@ -35,7 +35,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -62,7 +62,7 @@ jobs: api-integration: name: API Integration Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 env: COVERAGE_FILE: coverage-integration STORAGE_TYPE: opendal @@ -84,7 +84,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -105,7 +105,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -137,7 +137,7 @@ jobs: api-coverage: name: API Coverage - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 needs: - api-unit - api-integration @@ -156,7 +156,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 772ab8dd56..8a1719da3c 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,7 +13,7 @@ permissions: jobs: autofix: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Complete merge group check if: github.event_name == 'merge_group' @@ -25,7 +25,7 @@ jobs: - name: Check Docker Compose inputs if: github.event_name != 'merge_group' id: docker-compose-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | docker/generate_docker_compose @@ -35,7 +35,7 @@ jobs: - name: Check web inputs if: github.event_name != 'merge_group' id: web-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** @@ -48,7 +48,7 @@ jobs: - name: Check api inputs if: github.event_name != 'merge_group' id: api-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 
with: files: | api/** @@ -58,7 +58,7 @@ jobs: python-version: "3.11" - if: github.event_name != 'merge_group' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Generate Docker Compose if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' @@ -120,8 +120,7 @@ jobs: - name: ESLint autofix if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | - cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - if: github.event_name != 'merge_group' - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 5f16fc6927..2d8bde8080 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -26,6 +26,9 @@ jobs: build: runs-on: ${{ matrix.runs_on }} if: github.repository == 'langgenius/dify' + permissions: + contents: read + id-token: write strategy: matrix: include: @@ -35,28 +38,28 @@ jobs: build_context: "{{defaultContext}}:api" file: "Dockerfile" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 - service_name: "build-api-arm64" image_name_env: "DIFY_API_IMAGE_NAME" artifact_context: "api" build_context: "{{defaultContext}}:api" file: "Dockerfile" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 - service_name: "build-web-amd64" image_name_env: "DIFY_WEB_IMAGE_NAME" artifact_context: "web" build_context: "{{defaultContext}}" file: "web/Dockerfile" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 - service_name: "build-web-arm64" image_name_env: "DIFY_WEB_IMAGE_NAME" artifact_context: "web" build_context: "{{defaultContext}}" file: "web/Dockerfile" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 steps: - name: Prepare @@ -70,8 +73,8 @@ jobs: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 + - name: Set up Depot CLI + uses: depot/setup-action@v1 - name: Extract metadata for Docker id: meta @@ -81,16 +84,15 @@ jobs: - name: Build Docker image id: build - uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 + uses: depot/build-push-action@v1 with: + project: ${{ vars.DEPOT_PROJECT_ID }} context: ${{ matrix.build_context }} file: ${{ matrix.file }} platforms: ${{ matrix.platform }} build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} labels: ${{ steps.meta.outputs.labels }} outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true - cache-from: type=gha,scope=${{ matrix.service_name }} - cache-to: type=gha,mode=max,scope=${{ matrix.service_name }} - name: Export digest env: @@ -108,9 +110,33 @@ jobs: if-no-files-found: error retention-days: 1 + fork-build-validate: + if: github.repository != 'langgenius/dify' + runs-on: ubuntu-24.04 + strategy: + matrix: + include: + - service_name: "validate-api-amd64" + build_context: "{{defaultContext}}:api" + file: "Dockerfile" + - service_name: "validate-web-amd64" + build_context: "{{defaultContext}}" + file: "web/Dockerfile" + steps: + - name: Set up 
Docker Buildx + uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0 + + - name: Validate Docker image + uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0 + with: + push: false + context: ${{ matrix.build_context }} + file: ${{ matrix.file }} + platforms: linux/amd64 + create-manifest: needs: build - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: github.repository == 'langgenius/dify' strategy: matrix: diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 5991abe3ba..65f0149a74 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -9,7 +9,7 @@ concurrency: jobs: db-migration-test-postgres: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -59,7 +59,7 @@ jobs: run: uv run --directory api flask upgrade-db db-migration-test-mysql: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -110,6 +110,28 @@ jobs: sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env + # hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d` + # to return (container processes started); it does not wait on healthcheck + # status. mysql:8.0's first-time init takes 15-30s, so without an explicit + # wait the migration runs while InnoDB is still initialising and gets + # killed with "Lost connection during query". Poll a real SELECT until it + # succeeds. 
+ - name: Wait for MySQL to accept queries + run: | + set +e + for i in $(seq 1 60); do + if docker run --rm --network host mysql:8.0 \ + mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \ + -e 'SELECT 1' >/dev/null 2>&1; then + echo "MySQL ready after ${i}s" + exit 0 + fi + sleep 1 + done + echo "MySQL not ready after 60s; dumping container logs:" + docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql + exit 1 + - name: Run DB Migration env: DEBUG: true diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml index cd5fe9242e..9b9b77e0a2 100644 --- a/.github/workflows/deploy-agent-dev.yml +++ b/.github/workflows/deploy-agent-dev.yml @@ -13,7 +13,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/agent-dev' diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 954537663a..c2ff8c6332 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -10,7 +10,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/dev' diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 9cff3a3482..2740541f0f 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -13,7 +13,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/enterprise' diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml index c6f1cc7e6f..0da241cf95 100644 --- a/.github/workflows/deploy-hitl.yml +++ b/.github/workflows/deploy-hitl.yml @@ -10,7 +10,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'build/feat/hitl' diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 5752076c36..b0022b863b 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -14,40 +14,69 @@ concurrency: jobs: build-docker: + if: github.event.pull_request.head.repo.full_name == github.repository runs-on: ${{ matrix.runs_on }} + permissions: + contents: read + id-token: write strategy: matrix: include: - service_name: "api-amd64" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}:api" file: "Dockerfile" - service_name: "api-arm64" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}:api" file: "Dockerfile" - service_name: "web-amd64" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}" file: "web/Dockerfile" - service_name: "web-arm64" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}" file: "web/Dockerfile" steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 + - name: Set up Depot CLI + uses: depot/setup-action@v1 - name: Build Docker Image - uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 + uses: 
depot/build-push-action@v1 with: + project: ${{ vars.DEPOT_PROJECT_ID }} push: false context: ${{ matrix.context }} file: ${{ matrix.file }} platforms: ${{ matrix.platform }} - cache-from: type=gha - cache-to: type=gha,mode=max + + build-docker-fork: + if: github.event.pull_request.head.repo.full_name != github.repository + runs-on: ubuntu-24.04 + permissions: + contents: read + strategy: + matrix: + include: + - service_name: "api-amd64" + context: "{{defaultContext}}:api" + file: "Dockerfile" + - service_name: "web-amd64" + context: "{{defaultContext}}" + file: "web/Dockerfile" + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0 + + - name: Build Docker Image + uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0 + with: + push: false + context: ${{ matrix.context }} + file: ${{ matrix.file }} + platforms: linux/amd64 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 278e10bc04..f59cc6be48 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,7 +7,7 @@ jobs: permissions: contents: read pull-requests: write - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index ba36b5c07a..278f2ed8d1 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -23,7 +23,7 @@ concurrency: jobs: pre_job: name: Skip Duplicate Checks - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 outputs: should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} steps: @@ -39,7 +39,7 @@ jobs: name: Check Changed Files needs: pre_job if: needs.pre_job.outputs.should_skip != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 outputs: api-changed: ${{ steps.changes.outputs.api }} e2e-changed: ${{ steps.changes.outputs.e2e }} @@ -141,7 +141,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped API tests run: echo "No API-related changes detected; skipping API tests." @@ -154,7 +154,7 @@ jobs: - check-changes - api-tests-run - api-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize API Tests status env: @@ -201,7 +201,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped web tests run: echo "No web-related changes detected; skipping web tests." @@ -214,7 +214,7 @@ jobs: - check-changes - web-tests-run - web-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize Web Tests status env: @@ -260,7 +260,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped web full-stack e2e run: echo "No E2E-related changes detected; skipping web full-stack E2E." 
@@ -273,7 +273,7 @@ jobs: - check-changes - web-e2e-run - web-e2e-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize Web Full-Stack E2E status env: @@ -325,7 +325,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped VDB tests run: echo "No VDB-related changes detected; skipping VDB tests." @@ -338,7 +338,7 @@ jobs: - check-changes - vdb-tests-run - vdb-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize VDB Tests status env: @@ -384,7 +384,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped DB migration tests run: echo "No migration-related changes detected; skipping DB migration tests." @@ -397,7 +397,7 @@ jobs: - check-changes - db-migration-test-run - db-migration-test-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize DB Migration Test status env: diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index eefb1ebbb9..7f82942e7e 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -12,7 +12,7 @@ permissions: {} jobs: comment: name: Comment PR with pyrefly diff - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: actions: read contents: read @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>
', + }); + } diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index ac3732579c..0cf54e3585 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -10,7 +10,7 @@ permissions: jobs: pyrefly-diff: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: contents: read issues: write @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml index 51f3ca54b6..52c16f3153 100644 --- a/.github/workflows/pyrefly-type-coverage-comment.yml +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -12,7 +12,7 @@ permissions: {} jobs: comment: name: Comment PR with type coverage - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: actions: read contents: read @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true @@ -62,7 +62,7 @@ jobs: - name: Render coverage markdown from structured data id: render run: | - comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \ + comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \ --base base_report.json \ < pr_report.json)" diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml index c795c32e31..eae8debf1a 100644 --- a/.github/workflows/pyrefly-type-coverage.yml +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -10,7 +10,7 @@ permissions: jobs: pyrefly-type-coverage: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: contents: read issues: write @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index 49d2e94695..6f3193bbf5 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -16,7 +16,7 @@ jobs: name: Validate PR title permissions: pull-requests: read - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Complete merge group check if: github.event_name == 'merge_group' diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c74f4a670a..b23648c7c6 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -12,7 +12,7 @@ on: jobs: stale: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: issues: write pull-requests: write diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c32fc9d0cb..6b00899cf0 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -15,7 +15,7 @@ permissions: jobs: python-style: name: Python Style - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -25,7 +25,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # 
v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: false python-version: "3.12" @@ -57,7 +57,7 @@ jobs: web-style: name: Web Style - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 defaults: run: working-directory: ./web @@ -73,10 +73,12 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** + e2e/** + sdks/nodejs-client/** packages/** package.json pnpm-lock.yaml @@ -93,26 +95,28 @@ jobs: - name: Restore ESLint cache if: steps.changed-files.outputs.any_changed == 'true' id: eslint-cache-restore - uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache - key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: .eslintcache + key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web + env: + NODE_OPTIONS: --max-old-space-size=4096 run: vp run lint:tss - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . 
run: vp run type-check - name: Web dead code check @@ -122,14 +126,14 @@ jobs: - name: Save ESLint cache if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache + path: .eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: name: SuperLinter - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -140,7 +144,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | **.sh diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 467f31fccf..79fddb1853 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -18,7 +18,7 @@ concurrency: jobs: build: name: unit test for Node.js SDK - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 defaults: run: @@ -30,7 +30,7 @@ jobs: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 22 cache: '' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 541200293d..5f48c22c56 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -35,7 +35,7 @@ concurrency: jobs: translate: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 timeout-minutes: 120 steps: @@ -158,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93 + uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index 790ea9126d..87c88e2023 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -16,7 +16,7 @@ concurrency: jobs: trigger: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 timeout-minutes: 5 steps: diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml index f0def8fe7a..5c241af5c5 100644 --- a/.github/workflows/vdb-tests-full.yml +++ b/.github/workflows/vdb-tests-full.yml @@ -16,7 +16,7 @@ jobs: test: name: Full VDB Tests if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 strategy: matrix: python-version: @@ -36,7 +36,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -65,7 +65,7 @@ jobs: # tiflash - name: Set up Full Vector Store Matrix - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f3966f15b9..38ec96f00f 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: name: VDB Smoke Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 strategy: matrix: python-version: @@ -33,7 +33,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -62,7 +62,7 @@ jobs: # tiflash - name: Set up Vector Stores for Smoke Coverage - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index 10dc31bde8..bdc24887db 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: name: Web Full-Stack E2E - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 defaults: run: shell: bash @@ -28,7 +28,7 @@ jobs: uses: ./.github/actions/setup-web - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f3ab4c62c7..4619f3c104 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -16,7 +16,7 @@ concurrency: jobs: test: name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }}) - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 env: VITEST_COVERAGE_SCOPE: app-components strategy: @@ -54,7 +54,7 @@ jobs: name: Merge Test Reports if: ${{ !cancelled() }} needs: [test] - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: depot-ubuntu-24.04-4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 53dea88899..836bddbb49 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template +!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code 
History Extension @@ -236,9 +237,15 @@ scripts/stress-test/reports/ .playwright-mcp/ .serena/ +# vitest browser mode attachments (failure screenshots, traces, etc.) +.vitest-attachments/ +**/__screenshots__/ + # settings *.local.json *.local.md # Code Agent Folder .qoder/* + +.eslintcache diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index 13bbd81cf6..d48381bce2 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,44 +56,9 @@ if $api_modified; then fi fi -if $web_modified; then - if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 - fi - - echo "Running ESLint on web module" - - if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then - web_ts_modified=false - else - ts_diff_status=$? - if [ $ts_diff_status -eq 1 ]; then - web_ts_modified=true - else - echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." - exit $ts_diff_status - fi - fi - - cd ./web || exit 1 - vp staged - - if $web_ts_modified; then - echo "Running TypeScript type-check:tsgo" - if ! npm run type-check:tsgo; then - echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." - exit 1 - fi - else - echo "No staged TypeScript changes detected, skipping type-check:tsgo" - fi - - echo "Running knip" - if ! npm run knip; then - echo "Knip check failed. Please run 'npm run knip' to fix the errors." - exit 1 - fi - - cd ../ +if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 fi + +vp staged diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index c3e2c50c52..2611b75c6c 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -2,21 +2,10 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Flask API", + "name": "Python: API (gevent)", "type": "debugpy", "request": "launch", - "module": "flask", - "env": { - "FLASK_APP": "app.py", - "FLASK_ENV": "development" - }, - "args": [ - "run", - "--host=0.0.0.0", - "--port=5001", - "--no-debugger", - "--no-reload" - ], + "program": "${workspaceFolder}/api/app.py", "jinja": true, "justMyCode": true, "cwd": "${workspaceFolder}/api", diff --git a/web/.vscode/settings.example.json b/.vscode/settings.example.json similarity index 86% rename from web/.vscode/settings.example.json rename to .vscode/settings.example.json index 4b356f5b7a..7cdbc51a3b 100644 --- a/web/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -1,12 +1,16 @@ { - // Disable the default formatter, use eslint instead - "prettier.enable": false, - "editor.formatOnSave": false, + "cucumber.features": [ + "e2e/features/**/*.feature", + ], + "cucumber.glue": [ + "e2e/features/**/*.ts", + ], + + "tailwindCSS.experimental.configFile": "web/app/styles/globals.css", // Auto fix "editor.codeActionsOnSave": { "source.fixAll.eslint": "explicit", - "source.organizeImports": "never" }, // Silent the stylistic rules in your IDE, but still auto fix them diff --git a/AGENTS.md b/AGENTS.md index d25d2eed96..8be2daef95 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -30,7 +30,7 @@ The codebase is split into: ## Language Style - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation. -- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types. 
+- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types. ## General Practices diff --git a/README.md b/README.md index d9848a6c78..778028fc76 100644 --- a/README.md +++ b/README.md @@ -139,19 +139,6 @@ Star Dify on GitHub and be instantly notified of new releases. If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). -#### Customizing Suggested Questions - -You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions: - -```bash -# In your .env file -SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]' -SUGGESTED_QUESTIONS_MAX_TOKENS=512 -SUGGESTED_QUESTIONS_TEMPERATURE=0.3 -``` - -See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions. - ### Metrics Monitoring with Grafana Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more. @@ -160,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source ### Deployment with Kubernetes -If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. +If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) diff --git a/api/.env.example b/api/.env.example index a04a18944a..f6f65011ea 100644 --- a/api/.env.example +++ b/api/.env.example @@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -57,6 +60,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. +# Leave empty to preserve current unprefixed behavior. +REDIS_KEY_PREFIX= # redis Sentinel configuration. 
REDIS_USE_SENTINEL=false @@ -653,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + # Endpoint configuration ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} @@ -703,22 +714,6 @@ SWAGGER_UI_PATH=/swagger-ui.html # Set to false to export dataset IDs as plain text for easier cross-environment import DSL_EXPORT_ENCRYPT_DATASET_ID=true -# Suggested Questions After Answer Configuration -# These environment variables allow customization of the suggested questions feature -# -# Custom prompt for generating suggested questions (optional) -# If not set, uses the default prompt that generates 3 questions under 20 characters each -# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]" -# SUGGESTED_QUESTIONS_PROMPT= - -# Maximum number of tokens for suggested questions generation (default: 256) -# Adjust this value for longer questions or more questions -# SUGGESTED_QUESTIONS_MAX_TOKENS=256 - -# Temperature for suggested questions generation (default: 0.0) -# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions -# SUGGESTED_QUESTIONS_TEMPERATURE=0 - # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/api/.ruff.toml b/api/.ruff.toml index bd9684ef65..dd78024a02 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -106,3 +106,6 @@ msg = "Use Pydantic payload/query models instead of reqparse." [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.isort] +known-first-party = ["graphon"] \ No newline at end of file diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index 6bdfa2c039..1001559176 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -3,29 +3,21 @@ "compounds": [ { "name": "Launch Flask and Celery", - "configurations": ["Python: Flask", "Python: Celery"] + "configurations": ["Python: API (gevent)", "Python: Celery"] } ], "configurations": [ { - "name": "Python: Flask", - "consoleName": "Flask", + "name": "Python: API (gevent)", + "consoleName": "API", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", "cwd": "${workspaceFolder}", "envFile": ".env", - "module": "flask", + "program": "${workspaceFolder}/app.py", "justMyCode": true, - "jinja": true, - "env": { - "FLASK_APP": "app.py", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "run", - "--port=5001" - ] + "jinja": true }, { "name": "Python: Celery", diff --git a/api/README.md b/api/README.md index 00562f3f78..a075bc0fa9 100644 --- a/api/README.md +++ b/api/README.md @@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a uv run ruff format ./ # Format code uv run basedpyright . 
# Type checking ``` + +## Generate TS stub + +``` +uv run dev/generate_swagger_specs.py --output-dir openapi +``` + +use https://jsontotable.org/openapi-to-typescript to convert to typescript diff --git a/api/app.py b/api/app.py index c018c8a045..e53b037be5 100644 --- a/api/app.py +++ b/api/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys from typing import TYPE_CHECKING, cast @@ -9,17 +10,35 @@ if TYPE_CHECKING: celery: Celery +HOST = "0.0.0.0" +PORT = 5001 +logger = logging.getLogger(__name__) + + def is_db_command() -> bool: if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": return True return False +def log_startup_banner(host: str, port: int) -> None: + debugger_attached = sys.gettrace() is not None + logger.info("Serving Dify API via gevent WebSocket server") + logger.info("Bound to http://%s:%s", host, port) + logger.info("Debugger attached: %s", "on" if debugger_attached else "off") + logger.info("Press CTRL+C to quit") + + # create app +flask_app = None +socketio_app = None + if is_db_command(): from app_factory import create_migrations_app app = create_migrations_app() + socketio_app = app + flask_app = app else: # Gunicorn and Celery handle monkey patching automatically in production by # specifying the `gevent` worker class. Manual monkey patching is not required here. @@ -30,8 +49,14 @@ else: from app_factory import create_app - app = create_app() + socketio_app, flask_app = create_app() + app = flask_app celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs] + + log_startup_banner(HOST, PORT) + server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index 76838f9925..48e50ceae9 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,7 @@ import logging import time +import socketio # type: ignore[reportMissingTypeStubs] from flask import request from opentelemetry.trace import get_current_span from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID @@ -10,6 +11,7 @@ from contexts.wrapper import RecyclableContextVar from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp +from extensions.ext_socketio import sio from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import LicenseStatus @@ -122,14 +124,18 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[socketio.WSGIApp, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) - return app + return socketio_app, app def initialize_extensions(app: DifyApp): diff --git a/api/commands/account.py b/api/commands/account.py index 6a2a2e0428..761323a73d 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,6 +2,7 @@ import base64 import secrets import click +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -43,10 
+44,11 @@ def reset_password(email, new_password, password_confirm): # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = db.session.merge(account) - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + session.commit() AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account = db.session.merge(account) - account.email = normalized_new_email - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.email = normalized_new_email + session.commit() click.echo(click.style("Email updated successfully.", fg="green")) diff --git a/api/commands/plugin.py b/api/commands/plugin.py index c34391025a..8bd5392d7b 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,7 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant from models.oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) @@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..52e33c1789 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings): ) +class CreatorsPlatformConfig(BaseSettings): + """ + Configuration for Creators Platform integration + """ + + CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field( + description="Enable or disable Creators Platform features", + default=True, + ) + + CREATORS_PLATFORM_API_URL: HttpUrl = Field( + description="Creators Platform API URL", + default=HttpUrl("https://creators.dify.ai"), + ) + + CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = 
Field( + description="OAuth client ID for Creators Platform integration", + default="", + ) + + class EndpointConfig(BaseSettings): """ Configuration for various application endpoints and URLs @@ -1274,6 +1295,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1372,6 +1400,7 @@ class FeatureConfig( AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, + CreatorsPlatformConfig, TriggerConfig, AsyncWorkflowConfig, PluginConfig, @@ -1399,6 +1428,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 817284d26f..c392b8840f 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings): default="", ) + DB_SESSION_TIMEZONE_OVERRIDE: str = Field( + description=( + "PostgreSQL session timezone override injected via startup options." + " Default is 'UTC' for out-of-the-box consistency." + " Set to empty string to disable app-level timezone injection, for example when using RDS Proxy" + " together with a database-side default timezone." + ), + default="UTC", + ) + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str: @@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings): connect_args: dict[str, str] = {} # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"): - timezone_opt = "-c timezone=UTC" - if options: - merged_options = f"{options} {timezone_opt}" - else: - merged_options = timezone_opt - connect_args = {"options": merged_options} + merged_options = options.strip() + session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip() + if session_timezone_override: + timezone_opt = f"-c timezone={session_timezone_override}" + merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt + if merged_options: + connect_args = {"options": merged_options} result: SQLAlchemyEngineOptionsDict = { "pool_size": self.SQLALCHEMY_POOL_SIZE, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index b49275758a..2def0a0d4e 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -32,6 +32,11 @@ class RedisConfig(BaseSettings): default=0, ) + REDIS_KEY_PREFIX: str = Field( + description="Optional global prefix for Redis keys, topics, and transport artifacts", + default="", + ) + REDIS_USE_SSL: bool = Field( description="Enable SSL/TLS for the Redis connection", default=False, diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 0000000000..b0fbe0075c --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 4fe3fc9062..8e665c1386 100644 --- 
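The `DB_SESSION_TIMEZONE_OVERRIDE` change above replaces the hardcoded `-c timezone=UTC` with a configurable, optionally empty startup option. A minimal sketch of that merge logic as a pure function (the function name and asserts are illustrative, not part of the diff):

```python
def build_connect_args(options: str, timezone_override: str) -> dict[str, str]:
    merged = options.strip()
    override = timezone_override.strip()
    if override:
        tz_opt = f"-c timezone={override}"
        merged = f"{merged} {tz_opt}".strip() if merged else tz_opt
    return {"options": merged} if merged else {}

assert build_connect_args("", "UTC") == {"options": "-c timezone=UTC"}
assert build_connect_args("-c statement_timeout=0", "UTC") == {"options": "-c statement_timeout=0 -c timezone=UTC"}
assert build_connect_args("", "") == {}  # e.g. RDS Proxy with a database-side default timezone
```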
a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any -from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field +from graphon.file import helpers as file_helpers from models.model import IconType type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/common/human_input.py b/api/controllers/common/human_input.py new file mode 100644 index 0000000000..5d6f4efb95 --- /dev/null +++ b/api/controllers/common/human_input.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, JsonValue + + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict[str, JsonValue] + action: str diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22..980e828945 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -65,6 +65,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_comment, workflow_draft_variable, workflow_run, workflow_statistic, @@ -116,6 +117,7 @@ from .explore import ( saved_message, trial, ) +from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import tag controllers from .tag import tags @@ -201,6 +203,7 @@ __all__ = [ "saved_message", "setup", "site", + "socketio_workflow", "spec", "statistic", "tags", @@ -211,6 +214,7 @@ __all__ = [ "website", "workflow", "workflow_app_log", + "workflow_comment", "workflow_draft_variable", "workflow_run", "workflow_statistic", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2018f60215..a736fc8bc8 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,11 +5,9 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.enums import WorkflowExecutionStatus -from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -31,13 +29,15 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db from fields.base import ResponseModel +from graphon.enums import WorkflowExecutionStatus +from libs.helper import build_icon_url from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService -from services.entities.dsl_entities import ImportMode +from services.entities.dsl_entities import ImportMode, ImportStatus from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, @@ -129,6 +129,7 @@ class AppNamePayload(BaseModel): class AppIconPayload(BaseModel): icon: str | None = Field(default=None, description="Icon data") + icon_type: IconType | None = Field(default=None, description="Icon type") icon_background: str | None = Field(default=None, description="Icon background color") @@ -161,15 +162,6 @@ def _to_timestamp(value: datetime | int | None) -> int | 
None: return value -def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: - if icon is None or icon_type is None: - return None - icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) - if icon_type_value.lower() != IconType.IMAGE: - return None - return file_helpers.get_signed_file_url(icon) - - class Tag(ResponseModel): id: str name: str @@ -292,7 +284,7 @@ class Site(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) @field_validator("icon_type", mode="before") @classmethod @@ -342,7 +334,7 @@ class AppPartial(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) @field_validator("created_at", "updated_at", mode="before") @classmethod @@ -390,7 +382,7 @@ class AppDetailWithSite(AppDetail): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) class AppPagination(ResponseModel): @@ -632,7 +624,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -645,6 +637,13 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) + if result.status == ImportStatus.FAILED: + session.rollback() + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + session.rollback() + return result.model_dump(mode="json"), 202 + session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: @@ -693,6 +692,32 @@ class AppExportApi(Resource): return payload.model_dump(mode="json") +@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform") +class AppPublishToCreatorsPlatformApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=None) + @edit_permission_required + def post(self, app_model): + """Publish app to Creators Platform""" + from configs import dify_config + from core.helper.creators import get_redirect_url, upload_dsl + + if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + return {"error": "Creators Platform features are not enabled"}, 403 + + current_user, _ = current_account_with_tenant() + + dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False) + dsl_bytes = dsl_content.encode("utf-8") + + claim_code = upload_dsl(dsl_bytes) + redirect_url = get_redirect_url(str(current_user.id), claim_code) + + return {"redirect_url": redirect_url} + + @console_ns.route("/apps/<uuid:app_id>/name") class AppNameApi(Resource): @console_ns.doc("check_app_name") @@ -731,7 +756,12 @@ class AppIconApi(Resource): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") + app_model = app_service.update_app_icon(
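For reference, the deleted `_build_icon_url` above is the behavior now expected from `libs.helper.build_icon_url` (assuming the move kept it intact): only image-type icons get a signed file URL. A self-contained restatement with `IconType` and the signing helper stubbed:

```python
from enum import StrEnum

class IconType(StrEnum):
    IMAGE = "image"
    EMOJI = "emoji"

def sign(url: str) -> str:  # stand-in for file_helpers.get_signed_file_url
    return url + "?signed=1"

def build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
    if icon is None or icon_type is None:
        return None
    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
    if icon_type_value.lower() != IconType.IMAGE:
        return None
    return sign(icon)

assert build_icon_url(IconType.EMOJI, "x") is None
assert build_icon_url("image", "file-key") == "file-key?signed=1"
```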
app_model, + args.icon or "", + args.icon_background or "", + args.icon_type, + ) response_model = AppDetail.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 80bd7d1d8d..e91dc9cfe5 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,6 @@ from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models from controllers.console.app.wraps import get_app_model @@ -52,8 +52,9 @@ class AppImportApi(Resource): current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # AppDslService performs internal commits for some creation paths, so use a plain + # Session here instead of nesting it inside sessionmaker(...).begin(). + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Import app account = current_user @@ -69,6 +70,10 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") @@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource): # Check user role first current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_model: App): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 78ddb904e1..91fbe4a85a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource, fields -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,6 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from 
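The comment in the `app_import.py` hunk above explains the session choice; a runnable sketch of why: `sessionmaker(engine).begin()` commits or rolls back the block's transaction on exit, so a service that commits internally would end that outer transaction early, while a plain `Session` leaves the outcome to the caller. The sqlite engine and `run_import` are stand-ins for the real pieces.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")

def run_import(session: Session) -> str:
    return "completed"  # the real service may also commit internally for some paths

with Session(engine, expire_on_commit=False) as session:
    status = run_import(session)
    if status == "failed":
        session.rollback()
    else:
        session.commit()
```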
services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d83925d173..fe274e4c9a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,7 +3,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -27,6 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index d329d22309..b2b1049f0c 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -2,20 +2,37 @@ from typing import Literal import sqlalchemy as sa from flask import abort, request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.raws import FilesContainedField +from fields.conversation_fields import ( + Conversation as ConversationResponse, +) +from fields.conversation_fields import ( + ConversationDetail as ConversationDetailResponse, +) +from fields.conversation_fields import ( + ConversationMessageDetail as ConversationMessageDetailResponse, +) +from fields.conversation_fields import ( + ConversationPagination as ConversationPaginationResponse, +) +from fields.conversation_fields import ( + ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse, +) +from fields.conversation_fields import ( + ResultResponse, +) from libs.datetime_utils import naive_utc_now, parse_time_range -from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode @@ -62,267 +79,16 @@ console_ns.schema_model( ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model( - "SimpleAccount", - { - "id": fields.String, - "name": fields.String, - "email": fields.String, - }, -) - -feedback_stat_model = console_ns.model( - "FeedbackStat", - { - "like": fields.Integer, - "dislike": fields.Integer, - }, -) - -status_count_model = console_ns.model( - "StatusCount", - { - "success": fields.Integer, - "failed": fields.Integer, - "partial_success": fields.Integer, - "paused": fields.Integer, - }, -) - -message_file_model = console_ns.model( - 
"MessageFile", - { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), - }, -) - -agent_thought_model = console_ns.model( - "AgentThought", - { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), - }, -) - -simple_model_config_model = console_ns.model( - "SimpleModelConfig", - { - "model": fields.Raw(attribute="model_dict"), - "pre_prompt": fields.String, - }, -) - -model_config_model = console_ns.model( - "ModelConfig", - { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "model": fields.Raw, - "user_input_form": fields.Raw, - "pre_prompt": fields.String, - "agent_mode": fields.Raw, - }, -) - -# Models that depend on simple_account_model -feedback_model = console_ns.model( - "Feedback", - { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_model, allow_null=True), - }, -) - -annotation_model = console_ns.model( - "Annotation", - { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -annotation_hit_history_model = console_ns.model( - "AnnotationHitHistory", - { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - - -class MessageTextField(fields.Raw): - def format(self, value): - return value[0]["text"] if value else "" - - -# Simple message detail model -simple_message_detail_model = console_ns.model( - "SimpleMessageDetail", - { - "inputs": FilesContainedField, - "query": fields.String, - "message": MessageTextField, - "answer": fields.String, - }, -) - -# Message detail model that depends on multiple models -message_detail_model = console_ns.model( - "MessageDetail", - { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_model)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_model, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "message_files": fields.List(fields.Nested(message_file_model)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - }, -) - -# Conversation models -conversation_fields_model = console_ns.model( - "Conversation", - { - "id": fields.String, - 
"status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String(), - "from_account_id": fields.String, - "from_account_name": fields.String, - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotation": fields.Nested(annotation_model, allow_null=True), - "model_config": fields.Nested(simple_model_config_model), - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "message": fields.Nested(simple_message_detail_model, attribute="first_message"), - }, -) - -conversation_pagination_model = console_ns.model( - "ConversationPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_fields_model), attribute="items"), - }, -) - -conversation_message_detail_model = console_ns.model( - "ConversationMessageDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "model_config": fields.Nested(model_config_model), - "message": fields.Nested(message_detail_model, attribute="first_message"), - }, -) - -conversation_with_summary_model = console_ns.model( - "ConversationWithSummary", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String, - "from_account_id": fields.String, - "from_account_name": fields.String, - "name": fields.String, - "summary": fields.String(attribute="summary_or_query"), - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "model_config": fields.Nested(simple_model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "status_count": fields.Nested(status_count_model), - }, -) - -conversation_with_summary_pagination_model = console_ns.model( - "ConversationWithSummaryPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"), - }, -) - -conversation_detail_model = console_ns.model( - "ConversationDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "introduction": fields.String, - "model_config": fields.Nested(model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - }, +register_schema_models( + console_ns, + CompletionConversationQuery, + ChatConversationQuery, + ConversationResponse, + ConversationPaginationResponse, + ConversationMessageDetailResponse, + ConversationWithSummaryPaginationResponse, + ConversationDetailResponse, + ResultResponse, ) @@ -332,13 +98,12 @@ class CompletionConversationApi(Resource): @console_ns.doc(description="Get completion conversations 
with pagination and filtering") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -394,7 +159,9 @@ class CompletionConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationPaginationResponse.model_validate(conversations, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") @@ -402,19 +169,19 @@ class CompletionConversationDetailApi(Resource): @console_ns.doc("get_completion_conversation") @console_ns.doc(description="Get completion conversation details with messages") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", conversation_message_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationMessageDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_message_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationMessageDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_completion_conversation") @console_ns.doc(description="Delete a completion conversation") @@ -436,7 +203,7 @@ class CompletionConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @console_ns.route("/apps/<uuid:app_id>/chat-conversations") @@ -445,13 +212,12 @@ class ChatConversationApi(Resource): @console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ChatConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_with_summary_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationWithSummaryPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_with_summary_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -546,7 +312,9 @@ class ChatConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationWithSummaryPaginationResponse.model_validate(conversations,
from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") @@ -554,19 +322,19 @@ class ChatConversationDetailApi(Resource): @console_ns.doc("get_chat_conversation") @console_ns.doc(description="Get chat conversation details") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", conversation_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_chat_conversation") @console_ns.doc(description="Delete a chat conversation") @@ -588,7 +356,7 @@ class ChatConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 def _get_conversation(app_model, conversation_id): diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 369c26a80c..9c8b095b9f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,44 +1,86 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.conversation_variable_fields import ( - conversation_variable_fields, - paginated_conversation_variable_fields, -) +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from libs.login import login_required from models import ConversationVariable from models.model import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ConversationVariablesQuery(BaseModel): conversation_id: str = Field(..., description="Conversation ID to filter variables") -console_ns.schema_model( - ConversationVariablesQuery.__name__, - ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) -# For nested models,
need to replace nested dict with registered model -paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy() -paginated_conversation_variable_fields_copy["data"] = fields.List( - fields.Nested(conversation_variable_model), attribute="data" -) -paginated_conversation_variable_model = console_ns.model( - "PaginatedConversationVariable", paginated_conversation_variable_fields_copy +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + return value + try: + return serialize_value_type(value) + except Exception: + return serialize_value_type({"value_type": value}) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +register_schema_models( + console_ns, + ConversationVariablesQuery, + ConversationVariableResponse, + PaginatedConversationVariableResponse, ) @@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource): @console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) - @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) + @console_ns.response( + 200, + "Conversation variables retrieved successfully", + console_ns.models[PaginatedConversationVariableResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - @marshal_with(paginated_conversation_variable_model) def get(self, app_model): args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource): with sessionmaker(db.engine, expire_on_commit=False).begin() as session: rows = session.scalars(stmt).all() - return { - "page": page, - "limit": page_size, - "total": len(rows), - "has_more": False, - "data": [ - { - "created_at": row.created_at, - "updated_at": row.updated_at, - **row.to_variable().model_dump(), - } - for row in rows - ], - } + response = PaginatedConversationVariableResponse.model_validate( + { + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + ConversationVariableResponse.model_validate( + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + ) + for row in rows + ], + } + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7101d5df7b..c720a5e074 100644 --- a/api/controllers/console/app/generator.py +++ 
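The response models above lean on Pydantic's `mode="before"` validators to coerce raw ORM values (datetimes, enums, variable objects) before type checking. A self-contained sketch of the timestamp case used throughout this diff:

```python
from datetime import datetime
from pydantic import BaseModel, field_validator

class Example(BaseModel):
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return int(value.timestamp())  # the API contract is epoch seconds
        return value

m = Example.model_validate({"created_at": datetime(2024, 1, 1)})
assert isinstance(m.created_at, int)  # exact value depends on the local timezone
```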
b/api/controllers/console/app/generator.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -20,6 +19,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 5b1abc98dc..d517f695b8 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus from models.model import AppMCPServer -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict[str, Any] = Field(..., description="Server parameters configuration") @@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel): status: str | None = Field(default=None, description="Server status") +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + class AppMCPServerResponse(ResponseModel): id: str name: str server_code: str description: str - status: str + status: AppMCPServerStatus parameters: dict[str, Any] | list[Any] | str created_at: int | None = None updated_at: int | None = None @field_validator("parameters", mode="before") @classmethod - def _parse_json_string(cls, value: Any) -> Any: + def _normalize_parameters(cls, value: Any) -> Any: if isinstance(value, str): try: return json.loads(value) @@ -70,7 +70,9 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @login_required @account_initialization_required @setup_required @@ -85,7 +87,9 @@ class AppMCPServerController(Resource): @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @@ -111,13 +115,15 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") + 
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5a19544eab..44e19b57db 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,9 +1,9 @@ import logging +from datetime import datetime from typing import Literal from flask import request -from flask_restx import Resource, fields, marshal_with -from graphon.model_runtime.errors.invoke import InvokeError +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound @@ -25,10 +25,22 @@ from controllers.console.wraps import ( setup_required, ) from core.app.entities.app_invoke_entities import InvokeFrom +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from extensions.ext_database import db -from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from fields.base import ResponseModel +from fields.conversation_fields import ( + AgentThought, + ConversationAnnotation, + ConversationAnnotationHitHistory, + Feedback, + JSONValue, + MessageFile, + format_files_contained, + to_timestamp, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating @@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel): data: list[str] = Field(description="Suggested question") +class MessageDetailResponse(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + query: str + message: JSONValue | None = None + message_tokens: int | None = None + answer: str = Field(validation_alias="re_sign_file_url_answer") + answer_tokens: int | None = None + provider_response_latency: float | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] = Field(default_factory=list) + 
workflow_run_id: str | None = None + annotation: ConversationAnnotation | None = None + annotation_hit_history: ConversationAnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] = Field(default_factory=list) + message_files: list[MessageFile] = Field(default_factory=list) + extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list) + metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict") + status: str + error: str | None = None + parent_message_id: str | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class MessageInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[MessageDetailResponse] + + register_schema_models( console_ns, ChatMessagesQuery, @@ -105,124 +162,8 @@ register_schema_models( FeedbackExportQuery, AnnotationCountResponse, SuggestedQuestionsResponse, -) - -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model( - "SimpleAccount", - { - "id": fields.String, - "name": fields.String, - "email": fields.String, - }, -) - -message_file_model = console_ns.model( - "MessageFile", - { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), - }, -) - -agent_thought_model = console_ns.model( - "AgentThought", - { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), - }, -) - -# Models that depend on simple_account_model -feedback_model = console_ns.model( - "Feedback", - { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_model, allow_null=True), - }, -) - -annotation_model = console_ns.model( - "Annotation", - { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -annotation_hit_history_model = console_ns.model( - "AnnotationHitHistory", - { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -# Message detail model that depends on multiple models -message_detail_model = console_ns.model( - "MessageDetail", - { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": 
fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_model)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_model, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "message_files": fields.List(fields.Nested(message_file_model)), - "extra_contents": fields.List(fields.Raw), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - }, -) - -# Message infinite scroll pagination model -message_infinite_scroll_pagination_model = console_ns.model( - "MessageInfiniteScrollPagination", - { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_detail_model)), - }, + MessageDetailResponse, + MessageInfiniteScrollPaginationResponse, ) @@ -232,13 +173,12 @@ class ChatMessageListApi(Resource): @console_ns.doc(description="Get chat messages for a conversation with pagination") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__]) - @console_ns.response(200, "Success", message_infinite_scroll_pagination_model) + @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__]) @console_ns.response(404, "Conversation not found") @login_required @account_initialization_required @setup_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) @@ -298,7 +238,10 @@ class ChatMessageListApi(Resource): history_messages = list(reversed(history_messages)) attach_message_extra_contents(history_messages) - return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) + return MessageInfiniteScrollPaginationResponse.model_validate( + InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more), + from_attributes=True, + ).model_dump(mode="json") @console_ns.route("/apps/<uuid:app_id>/feedbacks") @@ -468,13 +411,12 @@ class MessageApi(Resource): @console_ns.doc("get_message") @console_ns.doc(description="Get message details by ID") @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) - @console_ns.response(200, "Message retrieved successfully", message_detail_model) + @console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__]) @console_ns.response(404, "Message not found") @get_app_model @setup_required @login_required @account_initialization_required - @marshal_with(message_detail_model) def get(self, app_model, message_id: str): message_id = str(message_id) @@ -486,4 +428,4 @@ class MessageApi(Resource): raise NotFound("Message Not Exists.") attach_message_extra_contents([message]) - return message + return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index da8d25c2eb..478f783eb0 100644 --- a/api/controllers/console/app/workflow.py +++
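The recurring migration in these controller hunks drops `@marshal_with` in favor of validating straight off the ORM object and dumping JSON-safe types. A self-contained sketch with a hypothetical model standing in for the project's `ResponseModel` base:

```python
from pydantic import BaseModel, ConfigDict

class ItemResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    status: str

class Item:  # stand-in for an ORM row
    id = "m-1"
    status = "normal"

payload = ItemResponse.model_validate(Item(), from_attributes=True).model_dump(mode="json")
assert payload == {"id": "m-1", "status": "normal"}
```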
b/api/controllers/console/app/workflow.py @@ -4,11 +4,7 @@ from collections.abc import Sequence from typing import Any from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from graphon.enums import NodeType -from graphon.file import File -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.utils.encoders import jsonable_encoder +from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -39,7 +35,13 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields +from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from graphon.enums import NodeType +from graphon.file import File +from graphon.file import helpers as file_helpers +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -47,6 +49,7 @@ from libs.login import current_account_with_tenant, login_required from models import App from models.model import AppMode from models.workflow import Workflow +from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -57,6 +60,7 @@ _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" +MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -150,6 +154,14 @@ class ConvertToWorkflowPayload(BaseModel): icon_background: str | None = None +class WorkflowFeaturesPayload(BaseModel): + features: dict[str, Any] = Field(..., description="Workflow feature configuration") + + +class WorkflowOnlineUsersQuery(BaseModel): + app_ids: str = Field(..., description="Comma-separated app IDs") + + class DraftWorkflowTriggerRunPayload(BaseModel): node_id: str @@ -173,6 +185,8 @@ reg(DefaultBlockConfigQuery) reg(ConvertToWorkflowPayload) reg(WorkflowListQuery) reg(WorkflowUpdatePayload) +reg(WorkflowFeaturesPayload) +reg(WorkflowOnlineUsersQuery) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) @@ -931,6 +945,32 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features") +class WorkflowFeaturesApi(Resource): + """Update draft workflow features.""" + + @console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__]) + @console_ns.doc("update_workflow_features") + @console_ns.doc(description="Update draft workflow features") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Workflow features updated successfully") + @setup_required + @login_required + @account_initialization_required +
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + + args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {}) + features = args.features + + workflow_service = WorkflowService() + workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user) + + return {"result": "success"} + + @console_ns.route("/apps/<uuid:app_id>/workflows") class PublishedAllWorkflowApi(Resource): @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @@ -942,7 +982,6 @@ class PublishedAllWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_pagination_model) @edit_permission_required def get(self, app_model: App): """ @@ -970,9 +1009,10 @@ class PublishedAllWorkflowApi(Resource): user_id=user_id, named_only=named_only, ) + serialized_workflows = marshal(workflows, workflow_fields_copy) return { - "items": workflows, + "items": serialized_workflows, "page": page, "limit": limit, "has_more": has_more, @@ -1340,3 +1380,62 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps/workflows/online-users") +class WorkflowOnlineUsersApi(Resource): + @console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__]) + @console_ns.doc("get_workflow_online_users") + @console_ns.doc(description="Get workflow online users") + @setup_required + @login_required + @account_initialization_required + @marshal_with(online_user_list_fields) + def get(self): + args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip())) + if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS: + raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.") + + if not app_ids: + return {"data": []} + + _, current_tenant_id = current_account_with_tenant() + workflow_service = WorkflowService() + accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id) + + results = [] + for app_id in app_ids: + if app_id not in accessible_app_ids: + continue + + users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}") + + users = [] + for _, user_info_json in users_json.items(): + try: + user_info = json.loads(user_info_json) + except Exception: + continue + + if not isinstance(user_info, dict): + continue + + avatar = user_info.get("avatar") + if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")): + try: + user_info["avatar"] = file_helpers.get_signed_file_url(avatar) + except Exception as exc: + logger.warning( + "Failed to sign workflow online user avatar; using original value.
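The online-users endpoint above dedupes the comma-separated `app_ids` while preserving order via `dict.fromkeys`; a minimal illustration of that idiom:

```python
raw = "a1, b2,a1,,c3"
app_ids = list(dict.fromkeys(part.strip() for part in raw.split(",") if part.strip()))
assert app_ids == ["a1", "b2", "c3"]  # duplicates and empties dropped, order kept
```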
" + "app_id=%s avatar=%s error=%s", + app_id, + avatar, + exc, + ) + + users.append(user_info) + results.append({"app_id": app_id, "users": users}) + + return {"data": results} diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 8ae6a78a62..4b39590235 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,27 +1,26 @@ from datetime import datetime +from typing import Any from dateutil.parser import isoparse from flask import request -from flask_restx import Resource, marshal_with -from graphon.enums import WorkflowExecutionStatus +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.workflow_app_log_fields import ( - build_workflow_app_log_pagination_model, - build_workflow_archived_log_pagination_model, -) +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode from services.workflow_app_service import WorkflowAppService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowAppLogQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword for filtering logs") @@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel): raise ValueError("Invalid boolean value for detail") -console_ns.schema_model( - WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None -# Register model for flask_restx to avoid dict type issues in Swagger -workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) -workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowRunForArchivedLogResponse(ResponseModel): + id: str + status: str | None = None + triggered_from: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return 
str(getattr(value, "value", value)) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: Any = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowArchivedLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForArchivedLogResponse | None = None + trigger_metadata: Any = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +class WorkflowArchivedLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowArchivedLogPartialResponse] + + +register_schema_models( + console_ns, + WorkflowAppLogQuery, + WorkflowRunForLogResponse, + WorkflowRunForArchivedLogResponse, + WorkflowAppLogPartialResponse, + WorkflowArchivedLogPartialResponse, + WorkflowAppLogPaginationResponse, + WorkflowArchivedLogPaginationResponse, +) @console_ns.route("/apps//workflow-app-logs") @@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource): @console_ns.doc(description="Get workflow application execution logs") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) + @console_ns.response( + 200, + "Workflow app logs retrieved successfully", + console_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_app_log_pagination_model) def get(self, app_model: App): """ Get workflow app logs @@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return WorkflowAppLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/apps//workflow-archived-logs") @@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource): @console_ns.doc(description="Get workflow archived execution logs") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @console_ns.response( + 200, + "Workflow archived logs retrieved successfully", + console_ns.models[WorkflowArchivedLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_archived_log_pagination_model) def get(self, app_model: App): """ Get 
workflow archived logs @@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource): limit=args.limit, ) - return workflow_app_log_pagination + return WorkflowArchivedLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 0000000000..e7c3e982a6 --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,335 @@ +import logging + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, TypeAdapter + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.member_fields import AccountWithRole +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowCommentCreatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float = Field(..., description="Comment X position") + position_y: float = Field(..., description="Comment Y position") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentUpdatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float | None = Field(default=None, description="Comment X position") + position_y: float | None = Field(default=None, description="Comment Y position") + mentioned_user_ids: list[str] | None = Field( + default=None, + description="Mentioned user IDs. 
+    )
+
+
+class WorkflowCommentReplyPayload(BaseModel):
+    content: str = Field(..., description="Reply content")
+    mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
+
+
+class WorkflowCommentMentionUsersPayload(BaseModel):
+    users: list[AccountWithRole]
+
+
+for model in (
+    WorkflowCommentCreatePayload,
+    WorkflowCommentUpdatePayload,
+    WorkflowCommentReplyPayload,
+):
+    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
+
+workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
+workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
+workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
+workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
+workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
+workflow_comment_reply_create_model = console_ns.model(
+    "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
+)
+workflow_comment_reply_update_model = console_ns.model(
+    "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
+)
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
+class WorkflowCommentListApi(Resource):
+    """API for listing and creating workflow comments."""
+
+    @console_ns.doc("list_workflow_comments")
+    @console_ns.doc(description="Get all comments for a workflow")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_basic_model, envelope="data")
+    def get(self, app_model: App):
+        """Get all comments for a workflow."""
+        comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
+
+        return comments
+
+    @console_ns.doc("create_workflow_comment")
+    @console_ns.doc(description="Create a new workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
+    @console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_create_model)
+    @edit_permission_required
+    def post(self, app_model: App):
+        """Create a new workflow comment."""
+        payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.create_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            created_by=current_user.id,
+            content=payload.content,
+            position_x=payload.position_x,
+            position_y=payload.position_y,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result, 201
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
+class WorkflowCommentDetailApi(Resource):
+    """API for managing individual workflow comments."""
+
+    @console_ns.doc("get_workflow_comment")
+    @console_ns.doc(description="Get a specific workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_detail_model)
+    def get(self, app_model: App, comment_id: str):
+        """Get a specific workflow comment."""
+        comment = WorkflowCommentService.get_comment(
+            tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
+        )
+
+        return comment
+
+    @console_ns.doc("update_workflow_comment")
+    @console_ns.doc(description="Update a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
+    @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_update_model)
+    @edit_permission_required
+    def put(self, app_model: App, comment_id: str):
+        """Update a workflow comment."""
+        payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.update_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+            content=payload.content,
+            position_x=payload.position_x,
+            position_y=payload.position_y,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result
+
+    @console_ns.doc("delete_workflow_comment")
+    @console_ns.doc(description="Delete a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(204, "Comment deleted successfully")
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @edit_permission_required
+    def delete(self, app_model: App, comment_id: str):
+        """Delete a workflow comment."""
+        WorkflowCommentService.delete_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+        )
+
+        return {"result": "success"}, 204
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
+class WorkflowCommentResolveApi(Resource):
+    """API for resolving and reopening workflow comments."""
+
+    @console_ns.doc("resolve_workflow_comment")
+    @console_ns.doc(description="Resolve a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_resolve_model)
+    @edit_permission_required
+    def post(self, app_model: App, comment_id: str):
+        """Resolve a workflow comment."""
+        comment = WorkflowCommentService.resolve_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+        )
+
+        return comment
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
+class WorkflowCommentReplyApi(Resource):
+    """API for managing comment replies."""
+
+    @console_ns.doc("create_workflow_comment_reply")
+    @console_ns.doc(description="Add a reply to a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
+    @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_create_model)
+    @edit_permission_required
+    def post(self, app_model: App, comment_id: str):
+        """Add a reply to a workflow comment."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.create_reply(
+            comment_id=comment_id,
+            content=payload.content,
+            created_by=current_user.id,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result, 201
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
+class WorkflowCommentReplyDetailApi(Resource):
+    """API for managing individual comment replies."""
+
+    @console_ns.doc("update_workflow_comment_reply")
+    @console_ns.doc(description="Update a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
+    @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_update_model)
+    @edit_permission_required
+    def put(self, app_model: App, comment_id: str, reply_id: str):
+        """Update a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
+
+        reply = WorkflowCommentService.update_reply(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            reply_id=reply_id,
+            user_id=current_user.id,
+            content=payload.content,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return reply
+
+    @console_ns.doc("delete_workflow_comment_reply")
+    @console_ns.doc(description="Delete a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.response(204, "Reply deleted successfully")
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @edit_permission_required
+    def delete(self, app_model: App, comment_id: str, reply_id: str):
+        """Delete a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        WorkflowCommentService.delete_reply(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            reply_id=reply_id,
+            user_id=current_user.id,
+        )
+
+        return {"result": "success"}, 204
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
+class WorkflowCommentMentionUsersApi(Resource):
+    """API for getting mentionable users for workflow comments."""
+
+    @console_ns.doc("workflow_comment_mention_users")
+    @console_ns.doc(description="Get all users in current tenant for mentions")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(
+        200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__]
+    )
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    def get(self, app_model: App):
+        """Get all users in current tenant for mentions."""
+        members = TenantService.get_tenant_members(current_user.current_tenant)
+        users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
+        response = WorkflowCommentMentionUsersPayload(users=users)
+        return response.model_dump(mode="json"), 200
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 657e794ac4..e32ba5f66c 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -5,10 +5,6 @@ from typing import Any, TypedDict

 from flask import Response, request
 from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.file import helpers as file_helpers
-from graphon.variables.segment_group import SegmentGroup
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import sessionmaker
@@ -22,8 +18,13 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.app.file_access import DatabaseFileAccessController
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
+from factories import variable_factory
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
+from graphon.file import helpers as file_helpers
+from graphon.variables.segment_group import SegmentGroup
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import App, AppMode
 from models.workflow import WorkflowDraftVariable
@@ -45,6 +46,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
     value: Any | None = Field(default=None, description="Variable value")


+class ConversationVariableUpdatePayload(BaseModel):
+    conversation_variables: list[dict[str, Any]] = Field(
+        ..., description="Conversation variables for the draft workflow"
+    )
+
+
+class EnvironmentVariableUpdatePayload(BaseModel):
+    environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
+
+
 console_ns.schema_model(
     WorkflowDraftVariableListQuery.__name__,
     WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
@@ -53,6 +64,14 @@ console_ns.schema_model(
     WorkflowDraftVariableUpdatePayload.__name__,
     WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
 )
+console_ns.schema_model(
+    ConversationVariableUpdatePayload.__name__,
+    ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+    EnvironmentVariableUpdatePayload.__name__,
+    EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)


 def _convert_values_to_json_serializable_object(value: Segment):
@@ -83,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):

 def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
     value_type = workflow_draft_var.value_type
-    return value_type.exposed_type().value
+    return str(value_type.exposed_type())

 class FullContentDict(TypedDict):
@@ -103,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
     result: FullContentDict = {
         "size_bytes": variable_file.size,
-        "value_type": variable_file.value_type.exposed_type().value,
+        "value_type": str(variable_file.value_type.exposed_type()),
         "length": variable_file.length,
         "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
     }
@@ -510,6 +529,34 @@ class ConversationVariableCollectionApi(Resource):
         db.session.commit()
         return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

+    @console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
+    @console_ns.doc("update_conversation_variables")
+    @console_ns.doc(description="Update conversation variables for workflow draft")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Conversation variables updated successfully")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @edit_permission_required
+    @get_app_model(mode=AppMode.ADVANCED_CHAT)
+    def post(self, app_model: App):
+        payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
+
+        workflow_service = WorkflowService()
+
+        conversation_variables_list = payload.conversation_variables
+        conversation_variables = [
+            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
+        ]
+
+        workflow_service.update_draft_workflow_conversation_variables(
+            app_model=app_model,
+            account=current_user,
+            conversation_variables=conversation_variables,
+        )
+
+        return {"result": "success"}
+

 @console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
 class SystemVariableCollectionApi(Resource):
@@ -551,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
                     "name": v.name,
                     "description": v.description,
                     "selector": v.selector,
-                    "value_type": v.value_type.exposed_type().value,
+                    "value_type": str(v.value_type.exposed_type()),
                     "value": v.value,
                     # Do not track edited for env vars.
                     "edited": False,
@@ -561,3 +608,31 @@ class EnvironmentVariableCollectionApi(Resource):
         )

         return {"items": env_vars_list}
+
+    @console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
+    @console_ns.doc("update_environment_variables")
+    @console_ns.doc(description="Update environment variables for workflow draft")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Environment variables updated successfully")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @edit_permission_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def post(self, app_model: App):
+        payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
+
+        workflow_service = WorkflowService()
+
+        environment_variables_list = payload.environment_variables
+        environment_variables = [
+            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
+        ]
+
+        workflow_service.update_draft_workflow_environment_variables(
+            app_model=app_model,
+            account=current_user,
+            environment_variables=environment_variables,
+        )
+
+        return {"result": "success"}
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index a1a075be71..6748d95d6b 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -3,8 +3,6 @@ from typing import Literal, TypedDict, cast

 from flask import request
 from flask_restx import Resource, fields, marshal_with
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -28,6 +26,8 @@ from fields.workflow_run_fields import (
     workflow_run_node_execution_list_fields,
     workflow_run_pagination_fields,
 )
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
 from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index c457684c15..a6715fa200 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,16 +1,17 @@
 import logging
+from datetime import datetime

 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource
+from pydantic import BaseModel, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import NotFound

 from configs import dify_config
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from extensions.ext_database import db
-from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
+from fields.base import ResponseModel
 from libs.login import current_user, login_required
 from models.enums import AppTriggerStatus
 from models.model import Account, App, AppMode
@@ -21,15 +22,6 @@
 from ..app.wraps import get_app_model
 from ..wraps import account_initialization_required, edit_permission_required, setup_required

 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
-
-triggers_list_fields_copy = triggers_list_fields.copy()
-triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
-triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
-
-webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)


 class Parser(BaseModel):
@@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
     enable_trigger: bool


-console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class WorkflowTriggerResponse(ResponseModel):
+    id: str
+    trigger_type: str
+    title: str
+    node_id: str
+    provider_name: str
+    icon: str
+    status: str
+    created_at: datetime | None = None
+    updated_at: datetime | None = None

-console_ns.schema_model(
-    ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0
+    @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+class WorkflowTriggerListResponse(ResponseModel):
+    data: list[WorkflowTriggerResponse]
+
+
+class WebhookTriggerResponse(ResponseModel):
+    id: str
+    webhook_id: str
+    webhook_url: str
+    webhook_debug_url: str
+    node_id: str
+    created_at: datetime | None = None
+
+    @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+register_schema_models(
+    console_ns,
+    Parser,
+    ParserEnable,
+    WorkflowTriggerResponse,
+    WorkflowTriggerListResponse,
+    WebhookTriggerResponse,
 )
@@ -57,7 +91,7 @@ class WebhookTriggerApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
         args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
         if not webhook_trigger:
             raise NotFound("Webhook trigger not found for this node")

-        return webhook_trigger
+        return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:app_id>/triggers")
@@ -89,7 +123,7 @@ class AppTriggersApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
     def get(self, app_model: App):
         """Get app triggers list"""
         assert isinstance(current_user, Account)
@@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
             else:
                 trigger.icon = ""  # type: ignore

-        return {"data": triggers}
+        return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
+            mode="json"
+        )


 @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
     @login_required
     @account_initialization_required
     @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
         args = ParserEnable.model_validate(console_ns.payload)
@@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
         else:
             trigger.icon = ""  # type: ignore

-        return trigger
+        return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index e6316bfd62..f7061f820f 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
@@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):

 class ActivationCheckResponse(BaseModel):
     is_valid: bool = Field(description="Whether token is valid")
-    data: dict | None = Field(default=None, description="Activation data if valid")
+    data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")


 class ActivationResponse(BaseModel):
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index b55cda4244..727428c8e7 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -5,11 +5,11 @@ from typing import Concatenate
 from flask import jsonify, request
 from flask.typing import ResponseReturnValue
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel
 from werkzeug.exceptions import BadRequest, NotFound

 from controllers.console.wraps import account_initialization_required, setup_required
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models import Account
 from models.model import OAuthProviderApp
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index b2a905366a..d001dfba64 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -2,7 +2,6 @@ from typing import Any, cast

 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -49,7 +48,9 @@ from fields.dataset_fields import (
     weighted_score_fields,
 )
 from fields.document_fields import document_status_fields
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_account_with_tenant, login_required
+from libs.url_utils import normalize_api_base_url
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermission, DatasetPermissionEnum
 from models.enums import ApiTokenType, SegmentStatus
@@ -889,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
+        base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+        return {"api_base_url": normalize_api_base_url(base)}


 @console_ns.route("/datasets/retrieval-setting")
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index de8fe1c0e2..3372a967d9 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -3,20 +3,19 @@ import logging
 from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from contextlib import ExitStack
+from datetime import datetime
 from typing import Any, Literal, cast

 import sqlalchemy as sa
 from flask import request, send_file
-from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from pydantic import BaseModel, Field
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services
 from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
-from controllers.common.schema import get_or_create_model, register_schema_models
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from core.errors.error import (
     LLMBadRequestError,
@@ -31,14 +30,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
-from fields.dataset_fields import dataset_fields
+from fields.base import ResponseModel
 from fields.document_fields import (
-    dataset_and_document_fields,
     document_fields,
-    document_metadata_fields,
     document_status_fields,
     document_with_segments_fields,
 )
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
@@ -72,27 +71,100 @@ from ..wraps import (

 logger = logging.getLogger(__name__)

-# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = get_or_create_model("Dataset", dataset_fields)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value

-document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)

-document_fields_copy = document_fields.copy()
-document_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_model = get_or_create_model("Document", document_fields_copy)
+def _normalize_enum(value: Any) -> Any:
+    if isinstance(value, str) or value is None:
+        return value
+    return getattr(value, "value", value)

-document_with_segments_fields_copy = document_with_segments_fields.copy()
-document_with_segments_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)

-dataset_and_document_fields_copy = dataset_and_document_fields.copy()
-dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
-dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
-dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+class DatasetResponse(ResponseModel):
+    id: str
+    name: str
+    description: str | None = None
+    permission: str | None = None
+    data_source_type: str | None = None
+    indexing_technique: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+
+    @field_validator("data_source_type", "indexing_technique", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> Any:
+        return _normalize_enum(value)
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class DocumentMetadataResponse(ResponseModel):
+    id: str
+    name: str
+    type: str
+    value: str | None = None
+
+
+class DocumentResponse(ResponseModel):
+    id: str
+    position: int | None = None
+    data_source_type: str | None = None
+    data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
+    data_source_detail_dict: Any = None
+    dataset_process_rule_id: str | None = None
+    name: str
+    created_from: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    tokens: int | None = None
+    indexing_status: str | None = None
+    error: str | None = None
+    enabled: bool | None = None
+    disabled_at: int | None = None
+    disabled_by: str | None = None
+    archived: bool | None = None
+    display_status: str | None = None
+    word_count: int | None = None
+    hit_count: int | None = None
+    doc_form: str | None = None
+    doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
+    summary_index_status: str | None = None
+    need_summary: bool | None = None
+
+    @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> Any:
+        return _normalize_enum(value)
+
+    @field_validator("doc_metadata", mode="before")
+    @classmethod
+    def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
+        if value is None:
+            return []
+        return value
+
+    @field_validator("created_at", "disabled_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class DocumentWithSegmentsResponse(DocumentResponse):
+    process_rule_dict: Any = None
+    completed_segments: int | None = None
+    total_segments: int | None = None
+
+
+class DatasetAndDocumentResponse(ResponseModel):
+    dataset: DatasetResponse
+    documents: list[DocumentResponse]
+    batch: str


 class DocumentRetryPayload(BaseModel):
@@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel):
     document_list: list[str]


+class DocumentMetadataUpdatePayload(BaseModel):
+    doc_type: str | None = None
+    doc_metadata: Any = None
+
+
 class DocumentDatasetListParam(BaseModel):
     page: int = Field(1, title="Page", description="Page number.")
     limit: int = Field(20, title="Limit", description="Page size.")
@@ -124,7 +201,13 @@ register_schema_models(
     DocumentRetryPayload,
     DocumentRenamePayload,
     GenerateSummaryPayload,
+    DocumentMetadataUpdatePayload,
     DocumentBatchDownloadZipPayload,
+    DatasetResponse,
+    DocumentMetadataResponse,
+    DocumentResponse,
+    DocumentWithSegmentsResponse,
+    DatasetAndDocumentResponse,
 )
@@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
+    @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
@@ -398,7 +481,9 @@
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()

-        return {"dataset": dataset, "documents": documents, "batch": batch}
+        return DatasetAndDocumentResponse.model_validate(
+            {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+        ).model_dump(mode="json")

     @setup_required
     @login_required
@@ -426,12 +511,13 @@ class DatasetInitApi(Resource):
     @console_ns.doc("init_dataset")
     @console_ns.doc(description="Initialize dataset with documents")
     @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
-    @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
+    @console_ns.response(
+        201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
+    )
     @console_ns.response(400, "Invalid request parameters")
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
@@ -479,9 +565,9 @@
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()

-        response = {"dataset": dataset, "documents": documents, "batch": batch}
-
-        return response
+        return DatasetAndDocumentResponse.model_validate(
+            {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+        ).model_dump(mode="json")


 @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
     @console_ns.doc("update_document_metadata")
     @console_ns.doc(description="Update document metadata")
     @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateDocumentMetadataRequest",
-            {
-                "doc_type": fields.String(description="Document type"),
-                "doc_metadata": fields.Raw(description="Document metadata"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
     @console_ns.response(200, "Document metadata updated successfully")
     @console_ns.response(404, "Document not found")
     @console_ns.response(403, "Permission denied")
@@ -1009,10 +1087,10 @@
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)

-        req_data = request.get_json()
+        req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})

-        doc_type = req_data.get("doc_type")
-        doc_metadata = req_data.get("doc_metadata")
+        doc_type = req_data.doc_type
+        doc_metadata = req_data.doc_metadata

         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
         if not current_user.is_dataset_editor:
@@ -1026,7 +1104,7 @@
         if not isinstance(doc_metadata, dict):
             raise ValueError("doc_metadata must be a dictionary.")

-        metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
+        metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])

         document.doc_metadata = {}
         if doc_type == "others":
@@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(document_model)
+    @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
     @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
     def post(self, dataset_id, document_id):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -1212,7 +1290,7 @@
         except services.errors.document.DocumentIndexingError:
             raise DocumentIndexingError("Cannot delete document during indexing.")

-        return document
+        return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 354c299bef..2647bb1f5a 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,7 +2,6 @@ import uuid

 from flask import request
 from flask_restx import Resource, marshal
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import String, cast, func, or_, select
 from sqlalchemy.dialects.postgresql import JSONB
@@ -32,6 +31,7 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.helper import escape_like_pattern
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import ChildChunk, DocumentSegment
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index e62be13c2f..36a7a4bb0e 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,13 +1,13 @@
-from flask_restx import Resource, fields
+from __future__ import annotations

-from controllers.common.schema import register_schema_model
-from fields.hit_testing_fields import (
-    child_chunk_fields,
-    document_fields,
-    files_fields,
-    hit_testing_record_fields,
-    segment_fields,
-)
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Resource
+from pydantic import Field, field_validator
+
+from controllers.common.schema import register_schema_models
+from fields.base import ResponseModel
 from libs.login import login_required

 from .. import console_ns
@@ -18,39 +18,92 @@ from ..wraps import (
     setup_required,
 )

-register_schema_model(console_ns, HitTestingPayload)
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value


-def _get_or_create_model(model_name: str, field_def):
-    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
-    existing = console_ns.models.get(model_name)
-    if existing is None:
-        existing = console_ns.model(model_name, field_def)
-    return existing
+class HitTestingDocument(ResponseModel):
+    id: str | None = None
+    data_source_type: str | None = None
+    name: str | None = None
+    doc_type: str | None = None
+    doc_metadata: Any | None = None


-# Register models for flask_restx to avoid dict type issues in Swagger
-document_model = _get_or_create_model("HitTestingDocument", document_fields)
+class HitTestingSegment(ResponseModel):
+    id: str | None = None
+    position: int | None = None
+    document_id: str | None = None
+    content: str | None = None
+    sign_content: str | None = None
+    answer: str | None = None
+    word_count: int | None = None
+    tokens: int | None = None
+    keywords: list[str] = Field(default_factory=list)
+    index_node_id: str | None = None
+    index_node_hash: str | None = None
+    hit_count: int | None = None
+    enabled: bool | None = None
+    disabled_at: int | None = None
+    disabled_by: str | None = None
+    status: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    indexing_at: int | None = None
+    completed_at: int | None = None
+    error: str | None = None
+    stopped_at: int | None = None
+    document: HitTestingDocument | None = None

-segment_fields_copy = segment_fields.copy()
-segment_fields_copy["document"] = fields.Nested(document_model)
-segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
+    @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)

-child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
-files_model = _get_or_create_model("HitTestingFile", files_fields)

-hit_testing_record_fields_copy = hit_testing_record_fields.copy()
-hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
-hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
-hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
-hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
+class HitTestingChildChunk(ResponseModel):
+    id: str | None = None
+    content: str | None = None
+    position: int | None = None
+    score: float | None = None

-# Response model for hit testing API
-hit_testing_response_fields = {
-    "query": fields.String,
-    "records": fields.List(fields.Nested(hit_testing_record_model)),
-}
-hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
+
+class HitTestingFile(ResponseModel):
+    id: str | None = None
+    name: str | None = None
+    size: int | None = None
+    extension: str | None = None
+    mime_type: str | None = None
+    source_url: str | None = None
+
+
+class HitTestingRecord(ResponseModel):
+    segment: HitTestingSegment | None = None
+    child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
+    score: float | None = None
+    tsne_position: Any | None = None
+    files: list[HitTestingFile] = Field(default_factory=list)
+    summary: str | None = None
+
+
+class HitTestingResponse(ResponseModel):
+    query: str
+    records: list[HitTestingRecord] = Field(default_factory=list)
+
+
+register_schema_models(
+    console_ns,
+    HitTestingPayload,
+    HitTestingDocument,
+    HitTestingSegment,
+    HitTestingChildChunk,
+    HitTestingFile,
+    HitTestingRecord,
+    HitTestingResponse,
+)


 @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @console_ns.doc(description="Test dataset knowledge retrieval")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
     @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
-    @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
+    @console_ns.response(
+        200,
+        "Hit testing completed successfully",
+        model=console_ns.models[HitTestingResponse.__name__],
+    )
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
     @setup_required
@@ -74,4 +131,4 @@
         args = payload.model_dump(exclude_none=True)
         self.hit_testing_args_check(args)

-        return self.perform_hit_testing(dataset, args)
+        return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 8fb3699849..699fa599c8 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -2,7 +2,6 @@ import logging
 from typing import Any

 from flask_restx import marshal
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -21,6 +20,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from fields.hit_testing_fields import hit_testing_record_fields
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_user
 from models.account import Account
 from services.dataset_service import DatasetService
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index bdf83b991e..fd0a8b33bc 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -2,8 +2,6 @@ from typing import Any

 from flask import make_response, redirect, request
 from flask_restx import Resource
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, NotFound
@@ -12,6 +10,8 @@ from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from core.plugin.impl.oauth import OAuthHandler
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import DatasourceProviderID
 from services.datasource_provider_service import DatasourceProviderService
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 3549f9542d..b31d73f27d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -4,7 +4,6 @@ from typing import Any, NoReturn

 from flask import Response, request
 from flask_restx import Resource, marshal, marshal_with
-from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Forbidden
@@ -28,6 +27,7 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
+from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import Account
 from models.dataset import Pipeline
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index a8077d9eb0..ee146e8287 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -4,7 +4,6 @@ from typing import Any, Literal, cast

 from flask import abort, request
 from flask_restx import Resource, marshal_with  # type: ignore
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -41,6 +40,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from factories import variable_factory
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant, current_user, login_required
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index a37077af42..ab660d9dc3 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -1,7 +1,6 @@
 import logging

 from flask import request
-from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError

 import services
@@ -20,6 +19,7 @@ from controllers.console.app.error import (
 )
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index eacd7332fe..ccdccceaa6 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -2,7 +2,6 @@ import logging
 from typing import Any, Literal
 from uuid import UUID

-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -26,6 +25,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_database import db
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 0740dd0e24..2d9a997fbf 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,21 +1,24 @@
 import logging
+from datetime import datetime
 from typing import Any

 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, computed_field, field_validator
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound

-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
-from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
+from fields.base import ResponseModel
+from graphon.file import helpers as file_helpers
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import App, InstalledApp, RecommendedApp
+from models.model import IconType
 from services.account_service import TenantService
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
@@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):

 logger = logging.getLogger(__name__)

-app_model = get_or_create_model("InstalledAppInfo", app_fields)
+def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
+    if icon is None or icon_type is None:
+        return None
+    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
+    if icon_type_value.lower() != IconType.IMAGE.value:
+        return None
+    return file_helpers.get_signed_file_url(icon)

-installed_app_fields_copy = installed_app_fields.copy()
-installed_app_fields_copy["app"] = fields.Nested(app_model)
-installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)

-installed_app_list_fields_copy = installed_app_list_fields.copy()
-installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
-installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
+def _safe_primitive(value: Any) -> Any:
+    if value is None or isinstance(value, (str, int, float, bool, datetime)):
+        return value
+    return None
+
+
+class InstalledAppInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    mode: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    use_icon_as_answer_icon: bool | None = None
+
+    @field_validator("mode", "icon_type", mode="before")
+    @classmethod
+    def _normalize_enum_like(cls, value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
+    @property
+    def icon_url(self) -> str | None:
+        return _build_icon_url(self.icon_type, self.icon)
+
+
+class InstalledAppResponse(ResponseModel):
+    id: str
+    app: InstalledAppInfoResponse
+    app_owner_tenant_id: str
+    is_pinned: bool
+    last_used_at: int | None = None
+    editable: bool
+    uninstallable: bool
+
+    @field_validator("app", mode="before")
+    @classmethod
+    def _normalize_app(cls, value: Any) -> Any:
+        if isinstance(value, dict):
+            return value
+        return {
+            "id": _safe_primitive(getattr(value, "id", "")) or "",
+            "name": _safe_primitive(getattr(value, "name", None)),
+            "mode": _safe_primitive(getattr(value, "mode", None)),
+            "icon_type": _safe_primitive(getattr(value, "icon_type", None)),
+            "icon": _safe_primitive(getattr(value, "icon", None)),
+            "icon_background": _safe_primitive(getattr(value, "icon_background", None)),
+            "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
+        }
+
+    @field_validator("last_used_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class InstalledAppListResponse(ResponseModel):
+    installed_apps: list[InstalledAppResponse]
+
+
+register_schema_models(
+    console_ns,
+    InstalledAppCreatePayload,
+    InstalledAppUpdatePayload,
+    InstalledAppsListQuery,
+    InstalledAppInfoResponse,
+    InstalledAppResponse,
+    InstalledAppListResponse,
+)


 @console_ns.route("/installed-apps")
 class InstalledAppsListApi(Resource):
     @login_required
     @account_initialization_required
-    @marshal_with(installed_app_list_model)
+    @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
     def get(self):
         query = InstalledAppsListQuery.model_validate(request.args.to_dict())
         current_user, current_tenant_id = current_account_with_tenant()
@@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
             )
         )

-        return {"installed_apps": installed_app_list}
+        return InstalledAppListResponse.model_validate(
+            {"installed_apps": installed_app_list}, from_attributes=True
+        ).model_dump(mode="json")

     @login_required
     @account_initialization_required
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 64d55d7ca3..209667d1d0 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -2,7 +2,6 @@ import logging
 from typing import Literal

 from flask import request
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, TypeAdapter
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -25,6 +24,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from fields.conversation_fields import ResultResponse
 from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.login import current_account_with_tenant
 from models.enums import FeedbackRating
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index c9920c97cf..55bd679b48 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,66 +1,83 @@
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, computed_field, field_validator

 from constants.languages import languages
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required
-from libs.helper import AppIconUrlField
+from fields.base import ResponseModel
+from libs.helper import build_icon_url
 from libs.login import current_user, login_required
 from services.recommended_app_service import RecommendedAppService

-app_fields = {
-    "id": fields.String,
-    "name": fields.String,
-    "mode": fields.String,
-    "icon": fields.String,
-    "icon_type": fields.String,
-    "icon_url": AppIconUrlField,
-    "icon_background": fields.String,
-}
-
-app_model = get_or_create_model("RecommendedAppInfo", app_fields)
-
-recommended_app_fields = {
-    "app": fields.Nested(app_model, attribute="app"),
-    "app_id": fields.String,
-    "description": fields.String(attribute="description"),
-    "copyright": fields.String,
-    "privacy_policy": fields.String,
-    "custom_disclaimer": fields.String,
-    "category": fields.String,
-    "position": fields.Integer,
-    "is_listed": fields.Boolean,
-    "can_trial": fields.Boolean,
-}
-
-recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
-
-recommended_app_list_fields = {
-    "recommended_apps": fields.List(fields.Nested(recommended_app_model)),
-    "categories": fields.List(fields.String),
-}
-
-recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
-

 class RecommendedAppsQuery(BaseModel):
     language: str | None = Field(default=None)


-console_ns.schema_model(
-    RecommendedAppsQuery.__name__,
-    RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
+class RecommendedAppInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    mode: str | None = None
+    icon: str | None = None
+    icon_type: str | None = None
+    icon_background: str | None = None
+
+    @staticmethod
+    def _normalize_enum_like(value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("mode", "icon_type", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> str | None:
+        return cls._normalize_enum_like(value)
+
+    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
+    @property
+    def icon_url(self) -> str | None:
+        return build_icon_url(self.icon_type, self.icon)
+
+
+class RecommendedAppResponse(ResponseModel):
+    app: RecommendedAppInfoResponse | None = None
+    app_id: str
+    description: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    category: str | None = None
+    position: int | None = None
+    is_listed: bool | None = None
+    can_trial: bool | None = None
+
+
+class RecommendedAppListResponse(ResponseModel):
+    recommended_apps: list[RecommendedAppResponse]
+    categories: list[str]
+
+
+register_schema_models(
+    console_ns,
+    RecommendedAppsQuery,
+    RecommendedAppInfoResponse,
+    RecommendedAppResponse,
+    RecommendedAppListResponse,
 )


 @console_ns.route("/explore/apps")
 class
RecommendedAppListApi(Resource):
     @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
+    @console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
     @login_required
     @account_initialization_required
-    @marshal_with(recommended_app_list_model)
     def get(self):
         # language args
         args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
         else:
             language_prefix = languages[0]
 
-        return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
+        return RecommendedAppListResponse.model_validate(
+            RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
+            from_attributes=True,
+        ).model_dump(mode="json")
 
 
 @console_ns.route("/explore/apps/<uuid:app_id>")
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
index e432574434..1456301a24 100644
--- a/api/controllers/console/explore/trial.py
+++ b/api/controllers/console/explore/trial.py
@@ -3,8 +3,6 @@ from typing import Any, Literal, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -61,6 +59,8 @@ from fields.workflow_fields import (
     workflow_fields,
     workflow_partial_fields,
 )
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from libs.login import current_user
@@ -169,6 +169,7 @@ console_ns.schema_model(
 
 
 class TrialAppWorkflowRunApi(TrialAppResource):
+    @trial_feature_enable
     @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
     def post(self, trial_app):
         """
@@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
 
 
 class TrialAppWorkflowTaskStopApi(TrialAppResource):
+    @trial_feature_enable
     def post(self, trial_app, task_id: str):
         """
         Stop workflow task
@@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource):
 
 
 class TrialMessageSuggestedQuestionApi(TrialAppResource):
-    @trial_feature_enable
     def get(self, trial_app, message_id):
         app_model = trial_app
         app_mode = AppMode.value_of(app_model.mode)
@@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource):
 
 
 class TrialSitApi(Resource):
     """Resource for trial app sites."""
 
-    @trial_feature_enable
     @get_app_model_with_trial(None)
     def get(self, app_model):
         """Retrieve app site info.
@@ -492,7 +492,6 @@ class TrialSitApi(Resource): class TrialAppParameterApi(Resource): """Resource for app variables.""" - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): """Retrieve app parameters.""" @@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource): class AppApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(app_detail_with_site_model) def get(self, app_model): @@ -534,7 +532,6 @@ class AppApi(Resource): class AppWorkflowApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(workflow_model) def get(self, app_model): @@ -547,7 +544,6 @@ class AppWorkflowApi(Resource): class DatasetListApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): page = request.args.get("page", default=1, type=int) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index da88de6776..438cce4fd8 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,7 +1,5 @@ import logging -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -23,6 +21,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index efa46c9779..7a6356d052 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,15 +1,18 @@ +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE -from fields.api_based_extension_fields import api_based_extension_fields +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService -from ..common.schema import register_schema_models +from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models from . 
import console_ns from .wraps import account_initialization_required, setup_required @@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel): api_key: str = Field(description="API key for authentication") -register_schema_models(console_ns, APIBasedExtensionPayload) +class CodeBasedExtensionResponse(ResponseModel): + module: str = Field(description="Module name") + data: Any = Field(description="Extension data") -api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) +def _mask_api_key(api_key: str) -> str: + if not api_key: + return api_key + if len(api_key) <= 8: + return api_key[0] + "******" + api_key[-1] + return api_key[:3] + "******" + api_key[-3:] -api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class APIBasedExtensionResponse(ResponseModel): + id: str + name: str + api_endpoint: str + api_key: str + created_at: int | None = None + + @field_validator("api_key", mode="before") + @classmethod + def _normalize_api_key(cls, value: str) -> str: + return _mask_api_key(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse) +console_ns.schema_model( + "APIBasedExtensionListResponse", + TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]: + return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json") @console_ns.route("/code-based-extension") @@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "CodeBasedExtensionResponse", - {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, - ), + console_ns.models[CodeBasedExtensionResponse.__name__], ) @setup_required @login_required @@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource): def get(self): query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} + return CodeBasedExtensionResponse( + module=query.module, + data=CodeBasedExtensionService.get_code_based_extension(query.module), + ).model_dump(mode="json") @console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): @console_ns.doc("get_api_based_extensions") @console_ns.doc(description="Get all API-based extensions for current tenant") - @console_ns.response(200, "Success", api_based_extension_list_model) + @console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self): _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + return [ + _serialize_api_based_extension(extension) + for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + ] @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new 
API-based extension") @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(201, "Extension created successfully", api_based_extension_model) + @console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self): payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() @@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return APIBasedExtensionService.save(extension_data) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data)) @console_ns.route("/api-based-extension/") @@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("get_api_based_extension") @console_ns.doc(description="Get API-based extension by ID") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.response(200, "Success", api_based_extension_model) + @console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self, id): api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + return _serialize_api_based_extension( + APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + ) @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(200, "Extension updated successfully", api_based_extension_model) + @console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self, id): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource): if payload.api_key != HIDDEN_VALUE: extension_data_from_db.api_key = payload.api_key - return APIBasedExtensionService.save(extension_data_from_db) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db)) @console_ns.doc("delete_api_based_extension") @console_ns.doc(description="Delete API-based extension") diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 845af37365..79b3e6cc9f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -8,10 +8,10 @@ from collections.abc import Generator from flask import Response, jsonify, request from flask_restx import Resource -from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError @@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator from 
core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.apps.message_generator import MessageGenerator
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
+from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
 from models import App
 from models.enums import CreatorUserRole
-from models.human_input import RecipientType
 from models.model import AppMode
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
@@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 logger = logging.getLogger(__name__)
 
 
-class HumanInputFormSubmitPayload(BaseModel):
-    inputs: dict
-    action: str
-
-
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
         if form.tenant_id != current_tenant_id:
             raise NotFoundError("App not found")
 
+    @staticmethod
+    def _ensure_console_recipient_type(form: Form) -> None:
+        if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
+            raise NotFoundError("form not found")
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
             raise NotFoundError(f"form not found, token={form_token}")
 
         self._ensure_console_access(form)
-
+        self._ensure_console_recipient_type(form)
         recipient_type = form.recipient_type
-        if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
-            raise NotFoundError(f"form not found, token={form_token}")
         # The type checker is not smart enough to validate the following invariant.
         # So we need to assert it manually.
         assert recipient_type is not None, "recipient_type cannot be None here."
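Note: the hunk above swaps the inline recipient-type check for `is_recipient_type_allowed_for_surface` from `core.workflow.human_input_policy`, a module this patch does not show. Below is a minimal sketch of what such a policy module could look like, assuming the console surface keeps accepting exactly the recipient types the removed check allowed (`CONSOLE` and `BACKSTAGE`); the enum values and the extra `WEB` member are illustrative guesses, not code from the repository:

```python
# Hypothetical sketch of core/workflow/human_input_policy.py (not part of this
# patch). It centralizes the recipient-type check the controller used to inline.
from enum import StrEnum

from models.human_input import RecipientType


class HumanInputSurface(StrEnum):
    # Assumed members and values; only CONSOLE is exercised by the diff above.
    CONSOLE = "console"
    WEB = "web"


# Assumption: the console surface allows the same recipient types as the
# removed inline check ({RecipientType.CONSOLE, RecipientType.BACKSTAGE}).
_ALLOWED_RECIPIENTS: dict[HumanInputSurface, frozenset[RecipientType]] = {
    HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}


def is_recipient_type_allowed_for_surface(
    recipient_type: RecipientType | None,
    surface: HumanInputSurface,
) -> bool:
    """Return True when a form with this recipient type may be served on the surface."""
    if recipient_type is None:
        return False
    return recipient_type in _ALLOWED_RECIPIENTS.get(surface, frozenset())
```

Keeping the allow-list in one module means every surface (console, web app, service API) applies the same policy instead of each controller re-deriving the set inline.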
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 551c86fd82..2a46d2250a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,7 +2,6 @@ import urllib.parse import httpx from flask_restx import Resource -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -16,6 +15,7 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/socketio/__init__.py b/api/controllers/console/socketio/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/controllers/console/socketio/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py new file mode 100644 index 0000000000..b4f03593fd --- /dev/null +++ b/api/controllers/console/socketio/workflow.py @@ -0,0 +1,108 @@ +import logging +from collections.abc import Callable +from typing import cast + +from flask import Request as FlaskRequest + +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.account_service import AccountService +from services.workflow_collaboration_service import WorkflowCollaborationService + +repository = WorkflowCollaborationRepository() +collaboration_service = WorkflowCollaborationService(repository, sio) + + +def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]: + return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event)) + + +@_sio_on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event, do authentication here. + """ + try: + request_environ = FlaskRequest(environ) + token = extract_access_token(request_environ) + except Exception: + logging.exception("Failed to extract token") + token = None + + if not token: + logging.warning("Socket connect rejected: missing token (sid=%s)", sid) + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid) + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) + return False + if not user.has_edit_permission: + logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid) + return False + + collaboration_service.save_socket_identity(sid, user) + return True + + except Exception: + logging.exception("Socket authentication failed") + return False + + +@_sio_on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. 
+ """ + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid) + if not result: + return {"msg": "unauthorized"}, 401 + + user_id, is_leader = result + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@_sio_on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + collaboration_service.disconnect_session(sid) + + +@_sio_on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, include: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + """ + return collaboration_service.relay_collaboration_event(sid, data) + + +@_sio_on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay. + """ + return collaboration_service.relay_graph_event(sid, data) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 39b84d3869..f73e2da54e 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,13 +1,14 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.enums import TagType from services.tag_service import ( @@ -18,17 +19,6 @@ from services.tag_service import ( UpdateTagPayload, ) -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} - - -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) - class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) @@ -47,17 +37,47 @@ class TagBindingRemovePayload(BaseModel): type: TagType = Field(description="Tag type") +class TagBindingItemDeletePayload(BaseModel): + target_id: str = Field(description="Target ID to unbind tag from") + type: TagType = Field(description="Tag type") + + class TagListQueryParam(BaseModel): type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter") keyword: str | None = Field(None, description="Search keyword") +class TagResponse(ResponseModel): + id: str + name: str + type: str | None = None + binding_count: str | None = None + + @field_validator("type", mode="before") + @classmethod + def normalize_type(cls, value: TagType | str | None) -> str | None: + if value is None: + return None + if isinstance(value, TagType): + return value.value + return value + + @field_validator("binding_count", mode="before") + @classmethod + def normalize_binding_count(cls, value: int | str | None) -> str | None: + if value is None: + return None + return str(value) + + register_schema_models( console_ns, 
TagBasePayload, TagBindingPayload, TagBindingRemovePayload, + TagBindingItemDeletePayload, TagListQueryParam, + TagResponse, ) @@ -69,14 +89,18 @@ class TagListApi(Resource): @console_ns.doc( params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) - @marshal_with(dataset_tag_fields) + @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) def get(self): _, current_tenant_id = current_account_with_tenant() raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) - return tags, 200 + serialized_tags = [ + TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags + ] + + return serialized_tags, 200 @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @@ -91,7 +115,9 @@ class TagListApi(Resource): payload = TagBasePayload.model_validate(console_ns.payload or {}) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @@ -114,7 +140,9 @@ class TagUpdateDeleteApi(Resource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @@ -130,41 +158,107 @@ class TagUpdateDeleteApi(Resource): return "", 204 -@console_ns.route("/tag-bindings/create") -class TagBindingCreateApi(Resource): +def _require_tag_binding_edit_permission() -> None: + """ + Ensure the current account can edit tag bindings. + + Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant. 
+ """ + current_user, _ = current_account_with_tenant() + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + +def _create_tag_bindings() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding( + TagBindingCreatePayload( + tag_ids=payload.tag_ids, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +def _remove_tag_binding() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_id=payload.tag_id, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +@console_ns.route("/tag-bindings") +class TagBindingCollectionApi(Resource): + """Canonical collection resource for tag binding creation.""" + + @console_ns.doc("create_tag_binding") @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() + return _create_tag_bindings() - payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding( - TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) + +@console_ns.route("/tag-bindings/") +class TagBindingItemApi(Resource): + """Canonical item resource for tag binding deletion.""" + + @console_ns.doc("delete_tag_binding") + @console_ns.doc(params={"id": "Tag ID"}) + @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def delete(self, id): + _require_tag_binding_edit_permission() + payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_id=str(id), + target_id=payload.target_id, + type=payload.type, + ) ) - return {"result": "success"}, 200 +@console_ns.route("/tag-bindings/create") +class DeprecatedTagBindingCreateApi(Resource): + """Deprecated verb-based alias for tag binding creation.""" + + @console_ns.doc("create_tag_binding_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.") + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def post(self): + return _create_tag_bindings() + + @console_ns.route("/tag-bindings/remove") -class TagBindingDeleteApi(Resource): +class DeprecatedTagBindingRemoveApi(Resource): + """Deprecated verb-based alias for tag binding deletion.""" + + @console_ns.doc("delete_tag_binding_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc(description="Deprecated legacy alias. 
Use DELETE /tag-bindings/{id} instead.") @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding( - TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type) - ) - - return {"result": "success"}, 200 + return _remove_tag_binding() diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 60f712e476..59dd29fdac 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -35,22 +35,24 @@ def plugin_permission_required( return view(*args, **kwargs) if install_required: - if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: - raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.install_permission: + case TenantPluginPermission.InstallPermission.NOBODY: raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: - pass + case TenantPluginPermission.InstallPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.InstallPermission.EVERYONE: + pass if debug_required: - if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: - raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.debug_permission: + case TenantPluginPermission.DebugPermission.NOBODY: raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: - pass + case TenantPluginPermission.DebugPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.DebugPermission.EVERYONE: + pass return view(*args, **kwargs) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index c35006a7ee..c01286cc59 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -5,7 +5,7 @@ from typing import Any, Literal import pytz from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select @@ -37,9 +37,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db +from fields.base import ResponseModel from fields.member_fields import Account as AccountResponse +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now -from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus @@ -74,6 +76,10 @@ class AccountAvatarPayload(BaseModel): avatar: str +class 
AccountAvatarQuery(BaseModel): + avatar: str = Field(..., description="Avatar file ID") + + class AccountInterfaceLanguagePayload(BaseModel): interface_language: str @@ -159,6 +165,7 @@ def reg(cls: type[BaseModel]): reg(AccountInitPayload) reg(AccountNamePayload) reg(AccountAvatarPayload) +reg(AccountAvatarQuery) reg(AccountInterfaceLanguagePayload) reg(AccountInterfaceThemePayload) reg(AccountTimezonePayload) @@ -178,17 +185,57 @@ def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") -integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -integrate_model = console_ns.model("AccountIntegrate", integrate_fields) -integrate_list_model = console_ns.model( - "AccountIntegrateList", - {"data": fields.List(fields.Nested(integrate_model))}, + +class AccountIntegrateResponse(ResponseModel): + provider: str + created_at: int | None = None + is_bound: bool + link: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountIntegrateListResponse(ResponseModel): + data: list[AccountIntegrateResponse] + + +class EducationVerifyResponse(ResponseModel): + token: str | None = None + + +class EducationStatusResponse(ResponseModel): + result: bool | None = None + is_student: bool | None = None + expire_at: int | None = None + allow_refresh: bool | None = None + + @field_validator("expire_at", mode="before") + @classmethod + def _normalize_expire_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class EducationAutocompleteResponse(ResponseModel): + data: list[str] = Field(default_factory=list) + curr_page: int | None = None + has_next: bool | None = None + + +register_schema_models( + console_ns, + AccountIntegrateResponse, + AccountIntegrateListResponse, + EducationVerifyResponse, + EducationStatusResponse, + EducationAutocompleteResponse, ) @@ -268,6 +315,18 @@ class AccountNameApi(Resource): @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): + @console_ns.expect(console_ns.models[AccountAvatarQuery.__name__]) + @console_ns.doc("get_account_avatar") + @console_ns.doc(description="Get account avatar url") + @setup_required + @login_required + @account_initialization_required + def get(self): + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + avatar_url = file_helpers.get_signed_file_url(args.avatar) + return {"avatar_url": avatar_url} + @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__]) @setup_required @login_required @@ -359,7 +418,7 @@ class AccountIntegrateApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__]) def get(self): account, _ = current_account_with_tenant() @@ -395,7 +454,9 @@ class AccountIntegrateApi(Resource): } ) - return {"data": integrate_data} + return AccountIntegrateListResponse( + data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data] + ).model_dump(mode="json") @console_ns.route("/account/delete/verify") @@ -447,31 +508,22 
@@ class AccountDeleteUpdateFeedbackApi(Resource): @console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): - verify_fields = { - "token": fields.String, - } - @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(verify_fields) + @console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - return BillingService.EducationIdentity.verify(account.id, account.email) + return EducationVerifyResponse.model_validate( + BillingService.EducationIdentity.verify(account.id, account.email) or {} + ).model_dump(mode="json") @console_ns.route("/account/education") class EducationApi(Resource): - status_fields = { - "result": fields.Boolean, - "is_student": fields.Boolean, - "expire_at": TimestampField, - "allow_refresh": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @setup_required @login_required @@ -491,37 +543,33 @@ class EducationApi(Resource): @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(status_fields) + @console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - res = BillingService.EducationIdentity.status(account.id) + res = BillingService.EducationIdentity.status(account.id) or {} # convert expire_at to UTC timestamp from isoformat if res and "expire_at" in res: res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc) - return res + return EducationStatusResponse.model_validate(res).model_dump(mode="json") @console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): - data_fields = { - "data": fields.List(fields.String), - "curr_page": fields.Integer, - "has_next": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(data_fields) + @console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__]) def get(self): payload = request.args.to_dict(flat=True) args = EducationAutocompleteQuery.model_validate(payload) - return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) + return EducationAutocompleteResponse.model_validate( + BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {} + ).model_dump(mode="json") @console_ns.route("/account/change-email") @@ -547,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource): account = None user_email = None email_for_sending = args.email.lower() - if args.phase is not None and args.phase == "new_email": + # Default to the initial phase; any legacy/unexpected client input is + # coerced back to `old_email` so we never trust the caller to declare + # later phases without a verified predecessor token. + send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD + if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW: + send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW if args.token is None: raise InvalidTokenError() reset_data = AccountService.get_change_email_data(args.token) if reset_data is None: raise InvalidTokenError() + + # The token used to request a new-email code must come from the + # old-email verification step. 
This prevents the bypass described + # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + raise InvalidTokenError() user_email = reset_data.get("email", "") if user_email.lower() != current_user.email.lower(): @@ -572,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource): email=email_for_sending, old_email=user_email, language=language, - phase=args.phase, + phase=send_phase, ) return {"result": "success", "data": token} @@ -607,12 +667,31 @@ class ChangeEmailCheckApi(Resource): AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() + # Only advance tokens that were minted by the matching send-code step; + # refuse tokens that have already progressed or lack a phase marker so + # the chain `old_email -> old_email_verified -> new_email -> new_email_verified` + # is strictly enforced. + phase_transitions = { + AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if not isinstance(token_phase, str): + raise InvalidTokenError() + refreshed_phase = phase_transitions.get(token_phase) + if refreshed_phase is None: + raise InvalidTokenError() + # Verified, revoke the first token AccountService.revoke_change_email_token(args.token) - # Refresh token data by generating a new token + # Refresh token data by generating a new token that carries the + # upgraded phase so later steps can check it. _, new_token = AccountService.generate_change_email_token( - user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} + user_email, + code=args.code, + old_email=token_data.get("old_email"), + additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase}, ) AccountService.reset_change_email_error_rate_limit(user_email) @@ -642,13 +721,29 @@ class ChangeEmailResetApi(Resource): if not reset_data: raise InvalidTokenError() - AccountService.revoke_change_email_token(args.token) + # Only tokens that completed both verification phases may be used to + # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from + # the initial send-code step could be replayed directly here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + raise InvalidTokenError() + + # Bind the new email to the token that was mailed and verified, so a + # verified token cannot be reused with a different `new_email` value. + token_email = reset_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != normalized_new_email: + raise InvalidTokenError() old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() if current_user.email.lower() != old_email.lower(): raise AccountNotFound() + # Revoke only after all checks pass so failed attempts don't burn a + # legitimately verified token. 
+ AccountService.revoke_change_email_token(args.token) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 3fdcbc4710..764f488755 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index b6b9deb1f9..d4be07382a 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,14 +1,22 @@ +"""Console workspace endpoint controllers. + +This module exposes workspace-scoped plugin endpoint management APIs. The +canonical write routes follow resource-oriented paths, while the historical +verb-based aliases stay available as deprecated resources so OpenAPI metadata +marks only the legacy paths as deprecated. +""" + from typing import Any from flask import request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService @@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel): endpoint_id: str -class EndpointUpdatePayload(EndpointIdPayload): +class EndpointUpdatePayload(BaseModel): + settings: dict[str, Any] + name: str = Field(min_length=1) + + +class LegacyEndpointUpdatePayload(EndpointIdPayload): settings: dict[str, Any] name: str = Field(min_length=1) @@ -76,6 +89,7 @@ register_schema_models( EndpointCreatePayload, EndpointIdPayload, EndpointUpdatePayload, + LegacyEndpointUpdatePayload, EndpointListQuery, EndpointListForPluginQuery, EndpointCreateResponse, @@ -88,8 +102,60 @@ register_schema_models( ) -@console_ns.route("/workspaces/current/endpoints/create") -class EndpointCreateApi(Resource): +def _create_endpoint() -> dict[str, bool]: + """Create a plugin endpoint for the current workspace.""" + user, tenant_id = current_account_with_tenant() + + args = EndpointCreatePayload.model_validate(console_ns.payload) + + try: + return { + "success": EndpointService.create_endpoint( + tenant_id=tenant_id, + user_id=user.id, + plugin_unique_identifier=args.plugin_unique_identifier, + name=args.name, + settings=args.settings, + ) + } + except PluginPermissionDeniedError as e: + raise ValueError(e.description) from e + + +def _update_endpoint(endpoint_id: str) -> dict[str, bool]: + """Update a plugin endpoint identified by the canonical path parameter.""" + user, tenant_id = 
current_account_with_tenant() + + args = EndpointUpdatePayload.model_validate(console_ns.payload) + + return { + "success": EndpointService.update_endpoint( + tenant_id=tenant_id, + user_id=user.id, + endpoint_id=endpoint_id, + name=args.name, + settings=args.settings, + ) + } + + +def _delete_endpoint(endpoint_id: str) -> dict[str, bool]: + """Delete a plugin endpoint identified by the canonical path parameter.""" + user, tenant_id = current_account_with_tenant() + + return { + "success": EndpointService.delete_endpoint( + tenant_id=tenant_id, + user_id=user.id, + endpoint_id=endpoint_id, + ) + } + + +@console_ns.route("/workspaces/current/endpoints") +class EndpointCollectionApi(Resource): + """Canonical collection resource for endpoint creation.""" + @console_ns.doc("create_endpoint") @console_ns.doc(description="Create a new plugin endpoint") @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__]) @@ -104,22 +170,33 @@ class EndpointCreateApi(Resource): @is_admin_or_owner_required @account_initialization_required def post(self): - user, tenant_id = current_account_with_tenant() + return _create_endpoint() - args = EndpointCreatePayload.model_validate(console_ns.payload) - try: - return { - "success": EndpointService.create_endpoint( - tenant_id=tenant_id, - user_id=user.id, - plugin_unique_identifier=args.plugin_unique_identifier, - name=args.name, - settings=args.settings, - ) - } - except PluginPermissionDeniedError as e: - raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/create") +class DeprecatedEndpointCreateApi(Resource): + """Deprecated verb-based alias for endpoint creation.""" + + @console_ns.doc("create_endpoint_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc( + description=( + "Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead." 
+    )
+    @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
+    @console_ns.response(
+        200,
+        "Endpoint created successfully",
+        console_ns.models[EndpointCreateResponse.__name__],
+    )
+    @console_ns.response(403, "Admin privileges required")
+    @setup_required
+    @login_required
+    @is_admin_or_owner_required
+    @account_initialization_required
+    def post(self):
+        return _create_endpoint()
 
 
 @console_ns.route("/workspaces/current/endpoints/list")
@@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource):
     )
 
 
-@console_ns.route("/workspaces/current/endpoints/delete")
-class EndpointDeleteApi(Resource):
+@console_ns.route("/workspaces/current/endpoints/<string:id>")
+class EndpointItemApi(Resource):
+    """Canonical item resource for endpoint updates and deletion."""
+
     @console_ns.doc("delete_endpoint")
     @console_ns.doc(description="Delete a plugin endpoint")
+    @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
+    @console_ns.response(
+        200,
+        "Endpoint deleted successfully",
+        console_ns.models[EndpointDeleteResponse.__name__],
+    )
+    @console_ns.response(403, "Admin privileges required")
+    @setup_required
+    @login_required
+    @is_admin_or_owner_required
+    @account_initialization_required
+    def delete(self, id: str):
+        return _delete_endpoint(endpoint_id=id)
+
+    @console_ns.doc("update_endpoint")
+    @console_ns.doc(description="Update a plugin endpoint")
+    @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
+    @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
+    @console_ns.response(
+        200,
+        "Endpoint updated successfully",
+        console_ns.models[EndpointUpdateResponse.__name__],
+    )
+    @console_ns.response(403, "Admin privileges required")
+    @setup_required
+    @login_required
+    @is_admin_or_owner_required
+    @account_initialization_required
+    def patch(self, id: str):
+        return _update_endpoint(endpoint_id=id)
+
+
+@console_ns.route("/workspaces/current/endpoints/delete")
+class DeprecatedEndpointDeleteApi(Resource):
+    """Deprecated verb-based alias for endpoint deletion."""
+
+    @console_ns.doc("delete_endpoint_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(
+        description=(
+            "Deprecated legacy alias for deleting a plugin endpoint. "
+            "Use DELETE /workspaces/current/endpoints/{id} instead."
+        )
+    )
     @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
     @console_ns.response(
         200,
@@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource):
     @is_admin_or_owner_required
     @account_initialization_required
     def post(self):
-        user, tenant_id = current_account_with_tenant()
-        args = EndpointIdPayload.model_validate(console_ns.payload)
-
-        return {
-            "success": EndpointService.delete_endpoint(
-                tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
-            )
-        }
+        args = EndpointIdPayload.model_validate(console_ns.payload)
+        return _delete_endpoint(endpoint_id=args.endpoint_id)
 
 
 @console_ns.route("/workspaces/current/endpoints/update")
-class EndpointUpdateApi(Resource):
-    @console_ns.doc("update_endpoint")
-    @console_ns.doc(description="Update a plugin endpoint")
-    @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
+class DeprecatedEndpointUpdateApi(Resource):
+    """Deprecated verb-based alias for endpoint updates."""
+
+    @console_ns.doc("update_endpoint_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(
+        description=(
+            "Deprecated legacy alias for updating a plugin endpoint. "
+            "Use PATCH /workspaces/current/endpoints/{id} instead."
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index e4cfca9fa4..2a6f37aec8 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,12 +1,12 @@
 from flask_restx import Resource
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic import BaseModel
 from werkzeug.exceptions import Forbidden

 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from libs.login import current_account_with_tenant, login_required
 from models import TenantAccountRole
 from services.model_load_balancing_service import ModelLoadBalancingService
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index cbb9677309..4b10561fdb 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -3,13 +3,13 @@ from typing import Any, Literal

 from flask import request, send_file
 from flask_restx import Resource
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator

 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import uuid_value
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 9182dbb510..b2d07ff8f9 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -3,14 +3,14 @@ from typing import Any, cast

 from flask import request
 from flask_restx import Resource
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator

 from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import uuid_value
 from libs.login import current_account_with_tenant, login_required
 from services.model_load_balancing_service import ModelLoadBalancingService
@@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
 class ParserValidate(BaseModel):
     model: str
     model_type: ModelType
-    credentials: dict
+    credentials: dict[str, Any]


 console_ns.schema_model(
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index aa674a63b3..b3e344ccea 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -4,7 +4,6 @@ from typing import Any, Literal

 from flask import request, send_file
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden
@@ -15,6 +14,7 @@ from controllers.console import console_ns
 from controllers.console.workspace import plugin_permission_required
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from core.plugin.impl.exc import PluginDaemonClientSideError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
 from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index c9956501e2..34c9534de8 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -5,7 +5,6 @@ from urllib.parse import urlparse

 from flask import make_response, redirect, request, send_file
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Forbidden
@@ -28,6 +27,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.oauth import OAuthHandler
 from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
 from extensions.ext_database import db
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.helper import alphanumeric, uuid_value
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import ToolProviderID
@@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource):
             with sessionmaker(db.engine).begin() as session:
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+            parsed = urlparse(server_url)
+            sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}"
+            logger.warning(
+                "MCP authorization failed for provider %s (url=%s)",
+                provider_id,
+                sanitized_url,
+                exc_info=True,
+            )
             raise ValueError(f"Failed to connect to MCP server: {e}") from e
diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index 7a28a09861..d11b66244f 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -3,7 +3,6 @@ from typing import Any

 from flask import make_response, redirect, request
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, model_validator
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, Forbidden
@@ -16,6 +15,7 @@ from core.plugin.impl.oauth import OAuthHandler
 from core.trigger.entities.entities import SubscriptionBuilderUpdater
 from core.trigger.trigger_manager import TriggerManager
 from extensions.ext_database import db
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_user, login_required
 from models.account import Account
 from models.provider_ids import TriggerProviderID
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 42874e6033..565099db61 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -1,8 +1,9 @@
 import logging
+from datetime import datetime

 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields, marshal
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import Unauthorized
@@ -26,6 +27,7 @@ from controllers.console.wraps import (
 )
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
+from fields.base import ResponseModel
 from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, login_required
 from models.account import Tenant, TenantCustomConfigDict, TenantStatus
@@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
     name: str


+class TenantInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    plan: str | None = None
+    status: str | None = None
+    created_at: int | None = None
+    role: str | None = None
+    in_trial: bool | None = None
+    trial_end_reason: str | None = None
+    custom_config: dict | None = None
+    trial_credits: int | None = None
+    trial_credits_used: int | None = None
+    next_credit_reset_date: int | None = None
+
+    @field_validator("plan", "status", "trial_end_reason", mode="before")
+    @classmethod
+    def _normalize_enum_like(cls, value):
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_created_at(cls, value: datetime | int | None):
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
 def reg(cls: type[BaseModel]):
     console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
 reg(SwitchWorkspacePayload)
 reg(WorkspaceCustomConfigPayload)
 reg(WorkspaceInfoPayload)
+reg(TenantInfoResponse)


 provider_fields = {
     "provider_name": fields.String,
@@ -180,7 +214,7 @@ class TenantApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(tenant_fields)
+    @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
     def post(self):
         if request.path == "/info":
             logger.warning("Deprecated URL /info was used.")
@@ -200,7 +234,13 @@ class TenantApi(Resource):
         else:
             raise Unauthorized("workspace is archived")

-        return WorkspaceService.get_tenant_info(tenant), 200
+        return (
+            TenantInfoResponse.model_validate(
+                WorkspaceService.get_tenant_info(tenant),
+                from_attributes=True,
+            ).model_dump(mode="json"),
+            200,
+        )


 @console_ns.route("/workspaces/switch")
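The `mode="before"` validators on `TenantInfoResponse` are what let `WorkspaceService.get_tenant_info()` keep returning enums and `datetime` objects while the endpoint emits plain JSON. A small illustration of the normalization; `TenantPlan` here is a hypothetical stand-in for whatever enum the service actually returns:

```python
import enum
from datetime import datetime, timezone


class TenantPlan(enum.Enum):  # hypothetical; mirrors an enum with a .value
    SANDBOX = "sandbox"


info = {
    "id": "tenant-1",
    "plan": TenantPlan.SANDBOX,  # enum -> "sandbox" via _normalize_enum_like
    "status": "normal",          # plain strings pass through unchanged
    "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),  # -> epoch seconds
}

resp = TenantInfoResponse.model_validate(info, from_attributes=True)
assert resp.plan == "sandbox"
assert resp.created_at == 1704067200
```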
diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py
index 6c15f9aa8b..915a11dcdd 100644
--- a/api/controllers/inner_api/app/dsl.py
+++ b/api/controllers/inner_api/app/dsl.py
@@ -9,7 +9,7 @@ from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field
 from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session

 from controllers.common.schema import register_schema_model
 from controllers.console.wraps import setup_required
@@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):

         account.set_tenant_id(workspace_id)

-        with sessionmaker(db.engine).begin() as session:
+        with Session(db.engine, expire_on_commit=False) as session:
             dsl_service = AppDslService(session)
             result = dsl_service.import_app(
                 account=account,
@@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
                 name=args.name,
                 description=args.description,
             )
+            if result.status == ImportStatus.FAILED:
+                session.rollback()
+            else:
+                session.commit()

         if result.status == ImportStatus.FAILED:
             return result.model_dump(mode="json"), 400
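Replacing `sessionmaker(db.engine).begin()` with an explicit `Session` changes who decides the transaction outcome: the context manager no longer auto-commits on exit, and the handler commits or rolls back based on the import result. `expire_on_commit=False` keeps `result` readable after the commit. The pattern in isolation, as a sketch assuming SQLAlchemy 2.x (`do_import` and the `failed` predicate stand in for `AppDslService.import_app` and the `ImportStatus` check):

```python
from sqlalchemy.orm import Session


def run_import(engine, do_import):
    with Session(engine, expire_on_commit=False) as session:
        result = do_import(session)
        if result.failed:         # hypothetical predicate for ImportStatus.FAILED
            session.rollback()    # discard partially imported rows
        else:
            session.commit()      # persist; objects stay usable after commit
    return result
```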
""" if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.get(EndUser, user_id) + user_model = session.scalar( + select(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .limit(1) + ) if not user_model: user_model = EndUser( diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 8066f198bb..f652bbc581 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,7 +2,6 @@ from typing import Any, Union from flask import Response from flask_restx import Resource -from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -12,6 +11,7 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 4f7f7d9a98..182631e8f5 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -23,9 +23,11 @@ from .app import ( conversation, file, file_preview, + human_input_form, message, site, workflow, + workflow_events, ) from .dataset import ( dataset, @@ -50,6 +52,7 @@ __all__ = [ "file", "file_preview", "hit_testing", + "human_input_form", "index", "message", "metadata", @@ -58,6 +61,7 @@ __all__ = [ "segment", "site", "workflow", + "workflow_events", ] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 907dd1b06d..e818573b8f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import services @@ -22,6 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3142e5118e..31f2797d66 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,6 @@ from uuid import UUID from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -29,6 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError 
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
index 8066f198bb..f652bbc581 100644
--- a/api/controllers/mcp/mcp.py
+++ b/api/controllers/mcp/mcp.py
@@ -2,7 +2,6 @@ from typing import Any, Union

 from flask import Response
 from flask_restx import Resource
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
@@ -12,6 +11,7 @@ from controllers.mcp import mcp_ns
 from core.mcp import types as mcp_types
 from core.mcp.server.streamable_http import handle_mcp_request
 from extensions.ext_database import db
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from libs import helper
 from models.enums import AppMCPServerStatus
 from models.model import App, AppMCPServer, AppMode, EndUser
diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py
index 4f7f7d9a98..182631e8f5 100644
--- a/api/controllers/service_api/__init__.py
+++ b/api/controllers/service_api/__init__.py
@@ -23,9 +23,11 @@ from .app import (
     conversation,
     file,
     file_preview,
+    human_input_form,
     message,
     site,
     workflow,
+    workflow_events,
 )
 from .dataset import (
     dataset,
@@ -50,6 +52,7 @@ __all__ = [
     "file",
     "file_preview",
     "hit_testing",
+    "human_input_form",
     "index",
     "message",
     "metadata",
@@ -58,6 +61,7 @@ __all__ = [
     "segment",
     "site",
     "workflow",
+    "workflow_events",
 ]

 api.add_namespace(service_api_ns)
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 907dd1b06d..e818573b8f 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -2,7 +2,6 @@ import logging

 from flask import request
 from flask_restx import Resource
-from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError

 import services
@@ -22,6 +21,7 @@ from controllers.service_api.app.error import (
 )
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 from models.model import App, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 3142e5118e..31f2797d66 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -4,7 +4,6 @@ from uuid import UUID

 from flask import request
 from flask_restx import Resource
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -29,6 +28,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.helper.trace_id_helper import get_external_trace_id
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import UUIDStrOrEmpty
 from models.model import App, AppMode, EndUser
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 1ec289e2a2..ca4b18cb5e 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,3 +1,4 @@
+from datetime import datetime
 from typing import Any, Literal

 from flask import request
@@ -14,14 +15,13 @@ from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
+from fields._value_type_serializer import serialize_value_type
+from fields.base import ResponseModel
 from fields.conversation_fields import (
     ConversationInfiniteScrollPagination,
     SimpleConversation,
 )
-from fields.conversation_variable_fields import (
-    build_conversation_variable_infinite_scroll_pagination_model,
-    build_conversation_variable_model,
-)
+from graphon.variables.types import SegmentType
 from libs.helper import UUIDStrOrEmpty
 from models.model import App, AppMode, EndUser
 from services.conversation_service import ConversationService
@@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
     value: Any


+class ConversationVariableResponse(ResponseModel):
+    id: str
+    name: str
+    value_type: str
+    value: str | None = None
+    description: str | None = None
+    created_at: int | None = None
+    updated_at: int | None = None
+
+    @field_validator("value_type", mode="before")
+    @classmethod
+    def normalize_value_type(cls, value: Any) -> str:
+        exposed_type = getattr(value, "exposed_type", None)
+        if callable(exposed_type):
+            return str(exposed_type())
+        if isinstance(value, str):
+            try:
+                return str(SegmentType(value).exposed_type())
+            except ValueError:
+                return value
+        try:
+            return serialize_value_type(value)
+        except (AttributeError, TypeError, ValueError):
+            pass
+
+        try:
+            return serialize_value_type({"value_type": value})
+        except (AttributeError, TypeError, ValueError):
+            value_attr = getattr(value, "value", None)
+            if value_attr is not None:
+                return str(value_attr)
+            return str(value)
+
+    @field_validator("value", mode="before")
+    @classmethod
+    def normalize_value(cls, value: Any | None) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+    @field_validator("created_at", "updated_at", mode="before")
+    @classmethod
+    def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[ConversationVariableResponse]
+
+
 register_schema_models(
     service_api_ns,
     ConversationListQuery,
     ConversationRenamePayload,
     ConversationVariablesQuery,
     ConversationVariableUpdatePayload,
+    ConversationVariableResponse,
+    ConversationVariableInfiniteScrollPaginationResponse,
 )
@@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
             404: "Conversation not found",
         }
     )
+    @service_api_ns.response(
+        200,
+        "Variables retrieved successfully",
+        service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
+    )
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
-    @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
     def get(self, app_model: App, end_user: EndUser, c_id):
         """List all variables for a conversation.
@@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
         last_id = str(query_args.last_id) if query_args.last_id else None

         try:
-            return ConversationService.get_conversational_variable(
+            pagination = ConversationService.get_conversational_variable(
                 app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
             )
+            return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
+                pagination, from_attributes=True
+            ).model_dump(mode="json")
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
             404: "Conversation or variable not found",
         }
     )
+    @service_api_ns.response(
+        200,
+        "Variable updated successfully",
+        service_api_ns.models[ConversationVariableResponse.__name__],
+    )
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
-    @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
     def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
         """Update a conversation variable's value.
@@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
         payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})

         try:
-            return ConversationService.update_conversation_variable(
+            variable = ConversationService.update_conversation_variable(
                 app_model, conversation_id, variable_id, end_user, payload.value
             )
+            return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
         except services.errors.conversation.ConversationVariableNotExistsError:
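The conversation-variable hunk repeats the serialization migration used throughout this patch: flask-restx `marshal_with` models are replaced by pydantic `ResponseModel` subclasses that validate the service-layer object and dump it to JSON-safe primitives. Reduced to its skeleton, the pattern looks like this (a sketch; the `Thing*` names are placeholders, not project code):

```python
from datetime import datetime
from pydantic import BaseModel, field_validator


class ThingResponse(BaseModel):
    id: str
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _ts(cls, value: datetime | int | None) -> int | None:
        # Accept ORM datetimes or already-normalized ints.
        return int(value.timestamp()) if isinstance(value, datetime) else value


def serialize_thing(thing) -> dict:
    # from_attributes lets ORM rows validate like dicts; mode="json"
    # guarantees primitives that flask-restx can emit directly.
    return ThingResponse.model_validate(thing, from_attributes=True).model_dump(mode="json")
```

The `@service_api_ns.response(...)` decorators keep the generated OpenAPI docs in sync with the pydantic schema, which `marshal_with` previously did implicitly.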
+""" + +import json +import logging +from datetime import datetime + +from flask import Response +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.schema import register_schema_models +from controllers.service_api import service_api_ns +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface +from extensions.ext_database import db +from models.model import App, EndUser +from services.human_input_service import Form, FormNotFoundError, HumanInputService + +logger = logging.getLogger(__name__) + + +register_schema_models(service_api_ns, HumanInputFormSubmitPayload) + + +def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: + result: dict[str, str] = {} + for key, value in values.items(): + if value is None: + result[key] = "" + elif isinstance(value, (dict, list)): + result[key] = json.dumps(value, ensure_ascii=False) + else: + result[key] = str(value) + return result + + +def _to_timestamp(value: datetime) -> int: + return int(value.timestamp()) + + +def _jsonify_form_definition(form: Form) -> Response: + definition_payload = form.get_definition().model_dump() + payload = { + "form_content": definition_payload["rendered_content"], + "inputs": definition_payload["inputs"], + "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "user_actions": definition_payload["user_actions"], + "expiration_time": _to_timestamp(form.expiration_time), + } + return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") + + +def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None: + if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id: + raise NotFound("Form not found") + + +def _ensure_form_is_allowed_for_service_api(form: Form) -> None: + # Keep app-token callers scoped to the public web-form surface; internal HITL + # routes must continue to flow through console-only authentication. 
+ if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API): + raise NotFound("Form not found") + + +@service_api_ns.route("/form/human_input/") +class WorkflowHumanInputFormApi(Resource): + @service_api_ns.doc("get_human_input_form") + @service_api_ns.doc(description="Get a paused human input form by token") + @service_api_ns.doc(params={"form_token": "Human input form token"}) + @service_api_ns.doc( + responses={ + 200: "Form retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Form not found", + 412: "Form already submitted or expired", + } + ) + @validate_app_token + def get(self, app_model: App, form_token: str): + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFound("Form not found") + + _ensure_form_belongs_to_app(form, app_model) + _ensure_form_is_allowed_for_service_api(form) + service.ensure_form_active(form) + return _jsonify_form_definition(form) + + @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__]) + @service_api_ns.doc("submit_human_input_form") + @service_api_ns.doc(description="Submit a paused human input form by token") + @service_api_ns.doc(params={"form_token": "Human input form token"}) + @service_api_ns.doc( + responses={ + 200: "Form submitted successfully", + 400: "Bad request - invalid submission data", + 401: "Unauthorized - invalid API token", + 404: "Form not found", + 412: "Form already submitted or expired", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, form_token: str): + payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {}) + + service = HumanInputService(db.engine) + form = service.get_form_by_token(form_token) + if form is None: + raise NotFound("Form not found") + + _ensure_form_belongs_to_app(form, app_model) + _ensure_form_is_allowed_for_service_api(form) + + recipient_type = form.recipient_type + if recipient_type is None: + logger.warning("Recipient type is None for form, form_id=%s", form.id) + raise BadRequest("Form recipient type is invalid") + + try: + service.submit_form_by_token( + recipient_type=recipient_type, + form_token=form_token, + selected_action_id=payload.action, + form_data=payload.inputs, + submission_end_user_id=end_user.id, + ) + except FormNotFoundError: + raise NotFound("Form not found") + + return {}, 200 diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index e0a64ffe26..cc763fa89c 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,13 +1,12 @@ import logging +from collections.abc import Mapping +from datetime import datetime from typing import Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Namespace, Resource, fields -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -33,9 +32,13 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from 
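End to end, an integration learns about a paused run (for example from a `workflow_paused` event), then uses these two endpoints with its app token. A hypothetical client flow; the base URL, token, and the exact shapes of `user_actions` and `inputs` are assumptions, not the documented contract:

```python
import requests

BASE = "https://api.example.com/v1"                    # placeholder
HEADERS = {"Authorization": "Bearer app-xxxxxxxx"}     # placeholder app token
form_token = "token-from-workflow-paused-event"        # delivered by the paused run

# 1. Fetch the rendered form definition.
form = requests.get(f"{BASE}/form/human_input/{form_token}", headers=HEADERS).json()

# 2. Submit one of the advertised actions with the filled-in inputs.
resp = requests.post(
    f"{BASE}/form/human_input/{form_token}",
    headers=HEADERS,
    json={
        "user": "end-user-123",                         # required: WhereisUserArg.JSON
        "action": form["user_actions"][0]["id"],        # assumed shape of user_actions
        "inputs": {"approval_note": "looks good"},      # assumed form field
    },
)
resp.raise_for_status()
```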
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index e0a64ffe26..cc763fa89c 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -1,13 +1,12 @@
 import logging
+from collections.abc import Mapping
+from datetime import datetime
 from typing import Literal

 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Namespace, Resource, fields
-from graphon.enums import WorkflowExecutionStatus
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -33,9 +32,13 @@ from core.errors.error import (
 )
 from core.helper.trace_id_helper import get_external_trace_id
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.base import ResponseModel
+from fields.end_user_fields import SimpleEndUser
+from fields.member_fields import SimpleAccount
+from graphon.enums import WorkflowExecutionStatus
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
-from libs.helper import OptionalTimestampField, TimestampField
 from models.model import App, AppMode, EndUser
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
@@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):

 register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)


+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+def _enum_value(value):
+    return getattr(value, "value", value)
+
+
 class WorkflowRunStatusField(fields.Raw):
     def output(self, key, obj: WorkflowRun, **kwargs):
-        return obj.status.value
+        return _enum_value(obj.status)


 class WorkflowRunOutputsField(fields.Raw):
     def output(self, key, obj: WorkflowRun, **kwargs):
-        if obj.status == WorkflowExecutionStatus.PAUSED:
+        status = _enum_value(obj.status)
+        if status == WorkflowExecutionStatus.PAUSED.value:
             return {}
         outputs = obj.outputs_dict
         return outputs or {}


-workflow_run_fields = {
-    "id": fields.String,
-    "workflow_id": fields.String,
-    "status": WorkflowRunStatusField,
-    "inputs": fields.Raw,
-    "outputs": WorkflowRunOutputsField,
-    "error": fields.String,
-    "total_steps": fields.Integer,
-    "total_tokens": fields.Integer,
-    "created_at": TimestampField,
-    "finished_at": OptionalTimestampField,
-    "elapsed_time": fields.Float,
-}
+class WorkflowRunResponse(ResponseModel):
+    id: str
+    workflow_id: str
+    status: str
+    inputs: dict | list | str | int | float | bool | None = None
+    outputs: dict = Field(default_factory=dict)
+    error: str | None = None
+    total_steps: int | None = None
+    total_tokens: int | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    elapsed_time: float | int | None = None
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)


-def build_workflow_run_model(api_or_ns: Namespace):
-    """Build the workflow run model for the API or Namespace."""
-    return api_or_ns.model("WorkflowRun", workflow_run_fields)
+class WorkflowRunForLogResponse(ResponseModel):
+    id: str
+    version: str | None = None
+    status: str | None = None
+    triggered_from: str | None = None
+    error: str | None = None
+    elapsed_time: float | int | None = None
+    total_tokens: int | None = None
+    total_steps: int | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    exceptions_count: int | None = None
+
+    @field_validator("status", "triggered_from", mode="before")
+    @classmethod
+    def _normalize_enum(cls, value):
+        return _enum_value(value)
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowAppLogPartialResponse(ResponseModel):
+    id: str
+    workflow_run: WorkflowRunForLogResponse | None = None
+    details: dict | list | str | int | float | bool | None = None
+    created_from: str | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_from", "created_by_role", mode="before")
+    @classmethod
+    def _normalize_enum(cls, value):
+        return _enum_value(value)
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowAppLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowAppLogPartialResponse]
+
+
+register_schema_models(
+    service_api_ns,
+    WorkflowRunResponse,
+    WorkflowRunForLogResponse,
+    WorkflowAppLogPartialResponse,
+    WorkflowAppLogPaginationResponse,
+)
+
+
+def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
+    status = _enum_value(workflow_run.status)
+    raw_outputs = workflow_run.outputs_dict
+    if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
+        outputs: dict = {}
+    elif isinstance(raw_outputs, dict):
+        outputs = raw_outputs
+    elif isinstance(raw_outputs, Mapping):
+        outputs = dict(raw_outputs)
+    else:
+        outputs = {}
+    return WorkflowRunResponse.model_validate(
+        {
+            "id": workflow_run.id,
+            "workflow_id": workflow_run.workflow_id,
+            "status": status,
+            "inputs": workflow_run.inputs,
+            "outputs": outputs,
+            "error": workflow_run.error,
+            "total_steps": workflow_run.total_steps,
+            "total_tokens": workflow_run.total_tokens,
+            "created_at": workflow_run.created_at,
+            "finished_at": workflow_run.finished_at,
+            "elapsed_time": workflow_run.elapsed_time,
+        }
+    ).model_dump(mode="json")
+
+
+def _serialize_workflow_log_pagination(pagination) -> dict:
+    return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")


@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
+    @service_api_ns.response(
+        200,
+        "Workflow run details retrieved successfully",
+        service_api_ns.models[WorkflowRunResponse.__name__],
+    )
     def get(self, app_model: App, workflow_run_id: str):
         """Get a workflow task running detail.
@@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
         )
         if not workflow_run:
             raise NotFound("Workflow run not found.")
-        return workflow_run
+        return _serialize_workflow_run(workflow_run)


@service_api_ns.route("/workflows/run")
@@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
+    @service_api_ns.response(
+        200,
+        "Logs retrieved successfully",
+        service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
+    )
     def get(self, app_model: App):
         """Get workflow app logs.
@@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
             created_by_account=args.created_by_account,
         )

-        return workflow_app_log_pagination
+        return _serialize_workflow_log_pagination(workflow_app_log_pagination)
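`_serialize_workflow_run` deliberately reports empty `outputs` for paused runs, matching the old `WorkflowRunOutputsField` behavior, so callers never observe partial results. A quick check of that property (a sketch; `SimpleNamespace` stands in for a `WorkflowRun` row):

```python
from types import SimpleNamespace

paused = SimpleNamespace(
    id="run-1",
    workflow_id="wf-1",
    status=WorkflowExecutionStatus.PAUSED,
    inputs={"q": "hi"},
    outputs_dict={"partial": "should not leak"},  # suppressed while paused
    error=None,
    total_steps=3,
    total_tokens=42,
    created_at=None,
    finished_at=None,
    elapsed_time=1.5,
)

assert _serialize_workflow_run(paused)["outputs"] == {}
```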
+""" + +import json +from collections.abc import Generator + +from flask import Response, request +from flask_restx import Resource +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from controllers.service_api import service_api_ns +from controllers.service_api.app.error import NotWorkflowAppError +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.message_generator import MessageGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.task_entities import StreamEvent +from core.workflow.human_input_policy import HumanInputSurface +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from repositories.factory import DifyAPIRepositoryFactory +from services.workflow_event_snapshot_service import build_workflow_event_stream + + +@service_api_ns.route("/workflow//events") +class WorkflowEventsApi(Resource): + """Service API for getting workflow execution events after resume.""" + + @service_api_ns.doc("get_workflow_events") + @service_api_ns.doc(description="Get workflow execution events stream after resume") + @service_api_ns.doc( + params={ + "task_id": "Workflow run ID", + "user": "End user identifier (query param)", + "include_state_snapshot": ( + "Whether to replay from persisted state snapshot, " + 'specify `"true"` to include a status snapshot of executed nodes' + ), + "continue_on_pause": ( + "Whether to keep the stream open across workflow_paused events," + 'specify `"true"` to keep the stream open for `workflow_paused` events.' 
+ ), + } + ) + @service_api_ns.doc( + responses={ + 200: "SSE event stream", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) + def get(self, app_model: App, end_user: EndUser, task_id: str): + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: + raise NotWorkflowAppError() + + session_maker = sessionmaker(db.engine) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + workflow_run = repo.get_workflow_run_by_id_and_tenant_id( + tenant_id=app_model.tenant_id, + run_id=task_id, + ) + + if workflow_run is None: + raise NotFound("Workflow run not found") + + if workflow_run.app_id != app_model.id: + raise NotFound("Workflow run not found") + + if workflow_run.created_by_role != CreatorUserRole.END_USER: + raise NotFound("Workflow run not found") + + if workflow_run.created_by != end_user.id: + raise NotFound("Workflow run not found") + + workflow_run_entity = workflow_run + + if workflow_run_entity.finished_at is not None: + response = WorkflowResponseConverter.workflow_run_result_to_finish_response( + task_id=workflow_run_entity.id, + workflow_run=workflow_run_entity, + creator_user=end_user, + ) + + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + + def _generate_finished_events() -> Generator[str, None, None]: + yield f"data: {json.dumps(payload)}\n\n" + + event_generator = _generate_finished_events + else: + msg_generator = MessageGenerator() + generator: BaseAppGenerator + if app_mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app_mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise NotWorkflowAppError() + + include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" + continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true" + terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None + + def _generate_stream_events(): + if include_state_snapshot: + return generator.convert_to_event_stream( + build_workflow_event_stream( + app_mode=app_mode, + workflow_run=workflow_run_entity, + tenant_id=app_model.tenant_id, + app_id=app_model.id, + session_maker=session_maker, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=not continue_on_pause, + ) + ) + return generator.convert_to_event_stream( + msg_generator.retrieve_events( + app_mode, + workflow_run_entity.id, + terminal_events=terminal_events, + ), + ) + + event_generator = _generate_stream_events + + return Response( + event_generator(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index fd954be6b1..76519cad0a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -19,6 +18,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from 
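A consumer of this stream reads standard SSE `data:` frames; with `continue_on_pause=true` the connection stays open across `workflow_paused` events instead of terminating on them, and `include_state_snapshot=true` first replays the persisted state of already-executed nodes. A hypothetical client (URL and token are placeholders):

```python
import json
import requests

url = "https://api.example.com/v1/workflow/run-id-123/events"   # placeholder
params = {"user": "end-user-123", "continue_on_pause": "true"}
headers = {"Authorization": "Bearer app-xxxxxxxx"}              # placeholder

with requests.get(url, params=params, headers=headers, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separators and SSE fields we don't use
        event = json.loads(line[len("data: "):])
        print(event.get("event"))
```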
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index fd954be6b1..76519cad0a 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -2,7 +2,6 @@ from typing import Any, Literal, cast

 from flask import request
 from flask_restx import marshal
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field, TypeAdapter, field_validator
 from werkzeug.exceptions import Forbidden, NotFound
@@ -19,6 +18,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import DataSetTag
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_user
 from models.account import Account
 from models.dataset import DatasetPermissionEnum
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 6db047567f..bc28ecb6b7 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -1,4 +1,12 @@
+"""Service API endpoints for dataset document management.
+
+The canonical Service API paths use hyphenated route segments. Legacy underscore
+aliases remain registered for backward compatibility, but they must stay marked
+deprecated in generated API docs so clients migrate toward the canonical paths.
+"""
+
 import json
+from collections.abc import Mapping
 from contextlib import ExitStack
 from typing import Self
 from uuid import UUID
@@ -117,12 +125,137 @@ register_schema_models(
 )


-@service_api_ns.route(
-    "/datasets/<uuid:dataset_id>/document/create_by_text",
-    "/datasets/<uuid:dataset_id>/document/create-by-text",
-)
+def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
+    """Create a document from text for both canonical and legacy routes."""
+    payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
+    args = payload.model_dump(exclude_none=True)
+
+    dataset_id_str = str(dataset_id)
+    tenant_id_str = str(tenant_id)
+    dataset = db.session.scalar(
+        select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
+    )
+
+    if not dataset:
+        raise ValueError("Dataset does not exist.")
+
+    if not dataset.indexing_technique and not args["indexing_technique"]:
+        raise ValueError("indexing_technique is required.")
+
+    embedding_model_provider = payload.embedding_model_provider
+    embedding_model = payload.embedding_model
+    if embedding_model_provider and embedding_model:
+        DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
+
+    retrieval_model = payload.retrieval_model
+    if (
+        retrieval_model
+        and retrieval_model.reranking_model
+        and retrieval_model.reranking_model.reranking_provider_name
+        and retrieval_model.reranking_model.reranking_model_name
+    ):
+        DatasetService.check_reranking_model_setting(
+            tenant_id_str,
+            retrieval_model.reranking_model.reranking_provider_name,
+            retrieval_model.reranking_model.reranking_model_name,
+        )
+
+    if not current_user:
+        raise ValueError("current_user is required")
+
+    upload_file = FileService(db.engine).upload_text(
+        text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
+    )
+    data_source = {
+        "type": "upload_file",
+        "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+    }
+    args["data_source"] = data_source
+    knowledge_config = KnowledgeConfig.model_validate(args)
+    DocumentService.document_create_args_validate(knowledge_config)
+
+    try:
+        documents, batch = DocumentService.save_document_with_dataset_id(
+            dataset=dataset,
+            knowledge_config=knowledge_config,
+            account=current_user,
+            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+            created_from="api",
+        )
+    except ProviderTokenNotInitError as ex:
+        raise ProviderNotInitializeError(ex.description)
+    document = documents[0]
+
+    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
+    return documents_and_batch_fields, 200
+
+
+def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
+    """Update a document from text for both canonical and legacy routes."""
+    payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
+    dataset = db.session.scalar(
+        select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
+    )
+    args = payload.model_dump(exclude_none=True)
+    if not dataset:
+        raise ValueError("Dataset does not exist.")
+
+    retrieval_model = payload.retrieval_model
+    if (
+        retrieval_model
+        and retrieval_model.reranking_model
+        and retrieval_model.reranking_model.reranking_provider_name
+        and retrieval_model.reranking_model.reranking_model_name
+    ):
+        DatasetService.check_reranking_model_setting(
+            tenant_id,
+            retrieval_model.reranking_model.reranking_provider_name,
+            retrieval_model.reranking_model.reranking_model_name,
+        )
+
+    # indexing_technique is already set in dataset since this is an update
+    args["indexing_technique"] = dataset.indexing_technique
+
+    if args.get("text"):
+        text = args.get("text")
+        name = args.get("name")
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
+        data_source = {
+            "type": "upload_file",
+            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+        }
+        args["data_source"] = data_source
+
+    args["original_document_id"] = str(document_id)
+    knowledge_config = KnowledgeConfig.model_validate(args)
+    DocumentService.document_create_args_validate(knowledge_config)
+
+    try:
+        documents, batch = DocumentService.save_document_with_dataset_id(
+            dataset=dataset,
+            knowledge_config=knowledge_config,
+            account=current_user,
+            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+            created_from="api",
+        )
+    except ProviderTokenNotInitError as ex:
+        raise ProviderNotInitializeError(ex.description)
+    document = documents[0]
+
+    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
+    return documents_and_batch_fields, 200
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
 class DocumentAddByTextApi(DatasetApiResource):
-    """Resource for documents."""
+    """Resource for the canonical text document creation route."""

     @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
     @service_api_ns.doc("create_document_by_text")
@@ -138,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource):
     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_resource_check("documents", "dataset")
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
-    def post(self, tenant_id, dataset_id):
+    def post(self, tenant_id: str, dataset_id: UUID):
         """Create document by text."""
-        payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
-        args = payload.model_dump(exclude_none=True)
+        return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)

-        dataset_id = str(dataset_id)
-        tenant_id = str(tenant_id)
-        dataset = db.session.scalar(
-            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
+class DeprecatedDocumentAddByTextApi(DatasetApiResource):
+    """Deprecated resource alias for text document creation."""
+
+    @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
+    @service_api_ns.doc("create_document_by_text_deprecated")
+    @service_api_ns.doc(deprecated=True)
+    @service_api_ns.doc(
+        description=(
+            "Deprecated legacy alias for creating a new document by providing text content. "
+            "Use /datasets/{dataset_id}/document/create-by-text instead."
         )
-
-        if not dataset:
-            raise ValueError("Dataset does not exist.")
-
-        if not dataset.indexing_technique and not args["indexing_technique"]:
-            raise ValueError("indexing_technique is required.")
-
-        embedding_model_provider = payload.embedding_model_provider
-        embedding_model = payload.embedding_model
-        if embedding_model_provider and embedding_model:
-            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
-
-        retrieval_model = payload.retrieval_model
-        if (
-            retrieval_model
-            and retrieval_model.reranking_model
-            and retrieval_model.reranking_model.reranking_provider_name
-            and retrieval_model.reranking_model.reranking_model_name
-        ):
-            DatasetService.check_reranking_model_setting(
-                tenant_id,
-                retrieval_model.reranking_model.reranking_provider_name,
-                retrieval_model.reranking_model.reranking_model_name,
-            )
-
-        if not current_user:
-            raise ValueError("current_user is required")
-
-        upload_file = FileService(db.engine).upload_text(
-            text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
-        )
-        data_source = {
-            "type": "upload_file",
-            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+    )
+    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+    @service_api_ns.doc(
+        responses={
+            200: "Document created successfully",
+            401: "Unauthorized - invalid API token",
+            400: "Bad request - invalid parameters",
         }
-        args["data_source"] = data_source
-        knowledge_config = KnowledgeConfig.model_validate(args)
-        # validate args
-        DocumentService.document_create_args_validate(knowledge_config)
-
-        if not current_user:
-            raise ValueError("current_user is required")
-
-        try:
-            documents, batch = DocumentService.save_document_with_dataset_id(
-                dataset=dataset,
-                knowledge_config=knowledge_config,
-                account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
-                created_from="api",
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        document = documents[0]
-
-        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
-        return documents_and_batch_fields, 200
+    )
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_resource_check("documents", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
+    def post(self, tenant_id: str, dataset_id: UUID):
+        """Create document by text through the deprecated underscore alias."""
+        return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)


-@service_api_ns.route(
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
-)
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
 class DocumentUpdateByTextApi(DatasetApiResource):
-    """Resource for update documents."""
+    """Resource for the canonical text document update route."""

     @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
     @service_api_ns.doc("update_document_by_text")
@@ -229,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
         """Update document by text."""
-        payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
-        dataset = db.session.scalar(
-            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
+        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
+class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
+    """Deprecated resource alias for text document updates."""
+
+    @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
+    @service_api_ns.doc("update_document_by_text_deprecated")
+    @service_api_ns.doc(deprecated=True)
+    @service_api_ns.doc(
+        description=(
+            "Deprecated legacy alias for updating an existing document by providing text content. "
+            "Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
         )
-        args = payload.model_dump(exclude_none=True)
-        if not dataset:
-            raise ValueError("Dataset does not exist.")
-
-        retrieval_model = payload.retrieval_model
-        if (
-            retrieval_model
-            and retrieval_model.reranking_model
-            and retrieval_model.reranking_model.reranking_provider_name
-            and retrieval_model.reranking_model.reranking_model_name
-        ):
-            DatasetService.check_reranking_model_setting(
-                tenant_id,
-                retrieval_model.reranking_model.reranking_provider_name,
-                retrieval_model.reranking_model.reranking_model_name,
-            )
-
-        # indexing_technique is already set in dataset since this is an update
-        args["indexing_technique"] = dataset.indexing_technique
-
-        if args.get("text"):
-            text = args.get("text")
-            name = args.get("name")
-            if not current_user:
-                raise ValueError("current_user is required")
-            upload_file = FileService(db.engine).upload_text(
-                text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
-            )
-            data_source = {
-                "type": "upload_file",
-                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
-            }
-            args["data_source"] = data_source
-        # validate args
-        args["original_document_id"] = str(document_id)
-        knowledge_config = KnowledgeConfig.model_validate(args)
-        DocumentService.document_create_args_validate(knowledge_config)
-
-        try:
-            documents, batch = DocumentService.save_document_with_dataset_id(
-                dataset=dataset,
-                knowledge_config=knowledge_config,
-                account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
-                created_from="api",
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        document = documents[0]
-
-        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
-        return documents_and_batch_fields, 200
+    )
+    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+    @service_api_ns.doc(
+        responses={
+            200: "Document updated successfully",
+            401: "Unauthorized - invalid API token",
+            404: "Document not found",
+        }
+    )
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
+    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
+        """Update document by text through the deprecated underscore alias."""
+        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)


 @service_api_ns.route(
b/api/controllers/web/human_input_form.py @@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict from flask import Response, request from flask_restx import Resource -from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.web import web_ns from controllers.web.error import NotFoundError, WebFormRateLimitExceededError from controllers.web.site import serialize_app_site_payload @@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ logger = logging.getLogger(__name__) -class HumanInputFormSubmitPayload(BaseModel): - inputs: dict - action: str - - _FORM_SUBMIT_RATE_LIMITER = RateLimiter( prefix="web_form_submit_rate_limit", max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 39afdd843f..07ecf8035b 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,7 +2,6 @@ import logging from typing import Literal from flask import request -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.enums import FeedbackRating from models.model import AppMode diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 38aeccc642..fe31e9d4ac 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -14,6 +13,7 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 796e090976..98211193a0 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,7 +1,5 @@ import logging -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -24,6 +22,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 06c746990d..c22102c2ba 100644 
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 06c746990d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -4,20 +4,6 @@ import uuid
 from decimal import Decimal
 from typing import Union, cast
 
-from graphon.file import file_manager
-from graphon.model_runtime.entities import (
-    AssistantPromptMessage,
-    LLMUsage,
-    PromptMessage,
-    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import func, select
 
 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.file import file_manager
+from graphon.model_runtime.entities import (
+    AssistantPromptMessage,
+    LLMUsage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
+from graphon.model_runtime.entities.model_entities import ModelFeature
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from models.enums import CreatorUserRole
 from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index f07ac64498..0bc93ad34d 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -4,15 +4,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, TypedDict
 
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    PromptMessageTool,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
-
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit
 from core.agent.errors import AgentMaxIterationError
@@ -24,6 +15,14 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from models.model import Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 2b2e26987e..a2186be100 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -1,5 +1,6 @@
 import json
 
+from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -11,8 +12,6 @@ from graphon.model_runtime.entities import (
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from graphon.model_runtime.utils.encoders import jsonable_encoder
 
-from core.agent.cot_agent_runner import CotAgentRunner
-
 
 class CotChatAgentRunner(CotAgentRunner):
     def _organize_system_prompt(self) -> SystemPromptMessage:
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index d4c52a8eb1..51a30998ae 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -1,5 +1,6 @@
 import json
 
+from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -8,8 +9,6 @@ from graphon.model_runtime.entities.message_entities import (
 )
 from graphon.model_runtime.utils.encoders import jsonable_encoder
 
-from core.agent.cot_agent_runner import CotAgentRunner
-
 
 class CotCompletionAgentRunner(CotAgentRunner):
     def _organize_instruction_prompt(self) -> str:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index fdffde85d0..29de0b8b1c 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -4,6 +4,13 @@ from collections.abc import Generator
 from copy import deepcopy
 from typing import Any, Union
 
+from core.agent.base_agent_runner import BaseAgentRunner
+from core.agent.errors import AgentMaxIterationError
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool_engine import ToolEngine
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -19,14 +26,6 @@ from graphon.model_runtime.entities import (
     UserPromptMessage,
 )
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-
-from core.agent.base_agent_runner import BaseAgentRunner
-from core.agent.errors import AgentMaxIterationError
-from core.app.apps.base_app_queue_manager import PublishFrom
-from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
-from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
-from core.tools.entities.tool_entities import ToolInvokeMeta
-from core.tools.tool_engine import ToolEngine
 from models.model import Message
 
 logger = logging.getLogger(__name__)
@@ -300,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # update prompt tool
         for prompt_tool in prompt_messages_tools:
-            self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+            tool_instance = tool_instances.get(prompt_tool.name)
+            if tool_instance:
+                self.update_prompt_message_tool(tool_instance, prompt_tool)
 
         iteration_step += 1
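The last `fc_agent_runner.py` hunk replaces a direct `tool_instances[prompt_tool.name]` subscript with `.get()` plus a truthiness check, so a prompt tool whose backing instance has disappeared no longer raises `KeyError` mid-iteration. A standalone sketch of the pattern (names are stand-ins, not the runner's real state):

```python
tool_instances: dict[str, object] = {"web_search": object()}
prompt_tool_names = ["web_search", "calculator"]  # "calculator" has no instance

for name in prompt_tool_names:
    tool_instance = tool_instances.get(name)
    if tool_instance:
        print(f"updating prompt tool: {name}")
    # a missing instance is silently skipped instead of crashing the agent loop
```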
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index 46c1f1230d..f341ca5a0b 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -1,17 +1,16 @@
 import json
 import re
 from collections.abc import Generator
-from typing import Union
-
-from graphon.model_runtime.entities.llm_entities import LLMResultChunk
+from typing import Any, Union
 
 from core.agent.entities import AgentScratchpadUnit
+from graphon.model_runtime.entities.llm_entities import LLMResultChunk
 
 
 class CotAgentOutputParser:
     @classmethod
     def handle_react_stream_output(
-        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
     ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
             action_name = None
diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py
index 90aa7b5fd4..8d25863a91 100644
--- a/api/core/agent/plugin_entities.py
+++ b/api/core/agent/plugin_entities.py
@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
     identity: AgentStrategyIdentity
     parameters: list[AgentStrategyParameter] = Field(default_factory=list)
     description: I18nObject = Field(..., description="The description of the agent strategy")
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None
     features: list[AgentFeature] | None = None
     meta_version: str | None = None
 
     # pydantic configs
diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
index 7d1b11c008..c8ec7cb44d 100644
--- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
+++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
 
     @classmethod
     def validate_and_set_defaults(
-        cls, tenant_id: str, config: dict, only_structure_validate: bool = False
-    ) -> tuple[dict, list[str]]:
+        cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
+    ) -> tuple[dict[str, Any], list[str]]:
         if not config.get("sensitive_word_avoidance"):
             config["sensitive_word_avoidance"] = {"enabled": False}
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index b7dd55632e..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -1,14 +1,13 @@
 from typing import cast
 
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
 from core.app.app_config.entities import EasyUIBasedAppConfig
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 
 
 class ModelConfigConverter:
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index 981bd26961..02498c23e1 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -1,10 +1,9 @@
 from collections.abc import Mapping
 from typing import Any
 
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-
 from core.app.app_config.entities import ModelConfigEntity
 from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from models.model import AppModelConfigDict
 from models.provider_ids import ModelProviderID
@@ -41,7 +40,7 @@ class ModelConfigManager:
         )
 
     @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for model config
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 57c6d1c496..4c07445df3 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,7 +1,5 @@
 from typing import Any
 
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-
 from core.app.app_config.entities import (
     AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
@@ -9,6 +7,7 @@ from core.app.app_config.entities import (
     PromptTemplateEntity,
 )
 from core.prompt.simple_prompt_transform import ModelMode
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode, AppModelConfigDict
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
index c89e1b3c3d..ddb500cccf 100644
--- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -1,10 +1,9 @@
 import re
 from typing import Any, cast
 
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
-
 from core.app.app_config.entities import ExternalDataVariableEntity
 from core.external_data_tool.factory import ExternalDataToolFactory
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models.model import AppModelConfigDict
 
 _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 819aca864c..53563dc5da 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,14 +1,14 @@
 from enum import StrEnum, auto
 from typing import Any, Literal
 
-from graphon.file import FileUploadConfig
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity
 from pydantic import BaseModel, Field
 
 from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from core.rag.entities import MetadataFilteringCondition
+from graphon.file import FileUploadConfig
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
+from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity
 from models.model import AppMode
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py
index 959c3868b4..8f20ef2ff9 100644
--- a/api/core/app/app_config/features/file_upload/manager.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,9 +1,8 @@
 from collections.abc import Mapping
 from typing import Any
 
-from graphon.file import FileUploadConfig
-
 from constants import DEFAULT_FILE_NUMBER_LIMITS
+from graphon.file import FileUploadConfig
 
 
 class FileUploadConfigManager:
diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
index 2dddce349c..0c36992c77 100644
--- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
+++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
@@ -1,5 +1,7 @@
 from typing import Any
 
+CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000
+
 
 class SuggestedQuestionsAfterAnswerConfigManager:
     @classmethod
@@ -20,7 +22,11 @@ class SuggestedQuestionsAfterAnswerConfigManager:
     @classmethod
     def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
-        Validate and set defaults for suggested questions feature
+        Validate and set defaults for suggested questions feature.
+
+        Optional fields:
+        - prompt: custom instruction prompt.
+        - model: provider/model configuration for suggested question generation.
 
         :param config: app model config args
         """
@@ -39,4 +45,27 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
             raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
 
+        prompt = config["suggested_questions_after_answer"].get("prompt")
+        if prompt is not None and not isinstance(prompt, str):
+            raise ValueError("prompt in suggested_questions_after_answer must be of string type")
+        if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
+            raise ValueError(
+                f"prompt in suggested_questions_after_answer must be less than or equal to "
+                f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
+            )
+
+        if "model" in config["suggested_questions_after_answer"]:
+            model_config = config["suggested_questions_after_answer"]["model"]
+            if not isinstance(model_config, dict):
+                raise ValueError("model in suggested_questions_after_answer must be of object type")
+
+            if "provider" not in model_config or not isinstance(model_config["provider"], str):
+                raise ValueError("provider in suggested_questions_after_answer.model must be of string type")
+
+            if "name" not in model_config or not isinstance(model_config["name"], str):
+                raise ValueError("name in suggested_questions_after_answer.model must be of string type")
+
+            if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
+                raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")
+
         return config, ["suggested_questions_after_answer"]
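The `suggested_questions_after_answer` hunks above add optional `prompt` and `model` fields with type and length validation. Assuming the validator behaves as the added lines read, the prompt check distills to the following runnable sketch (the helper function name is invented; the real method also validates the `model` sub-object and returns the config):

```python
CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000  # mirrors the constant added above


def validate_suggested_questions_prompt(prompt: object) -> None:
    # Same checks as the hunk: the field is optional, must be a string,
    # and must fit within the length budget.
    if prompt is None:
        return
    if not isinstance(prompt, str):
        raise ValueError("prompt in suggested_questions_after_answer must be of string type")
    if len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
        raise ValueError(
            f"prompt in suggested_questions_after_answer must be less than or equal to "
            f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
        )


validate_suggested_questions_prompt("Suggest three short follow-up questions.")  # passes
validate_suggested_questions_prompt(None)  # passes: the field is optional
```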
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
index 62e0c31d1a..13ace32fd6 100644
--- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -1,8 +1,7 @@
 import re
 
-from graphon.variables.input_entities import VariableEntity
-
 from core.app.app_config.entities import RagPipelineVariableEntity
+from graphon.variables.input_entities import VariableEntity
 from models.workflow import Workflow
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 985ded0f74..b79d5514b4 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -18,11 +18,6 @@ from constants import UUID_NIL
 
 if TYPE_CHECKING:
     from controllers.console.app.workflow import LoopNodeRunPayload
 
-from graphon.graph_engine.layers import GraphEngineLayer
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.runtime import GraphRuntimeState
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
-
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -39,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
-from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
+    ChatbotAppBlockingResponse,
+    ChatbotAppStreamResponse,
+)
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -48,6 +47,10 @@ from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.runtime import GraphRuntimeState
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.enums import WorkflowRunTriggeredFrom
@@ -656,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
+    ) -> (
+        ChatbotAppBlockingResponse
+        | AdvancedChatPausedBlockingResponse
+        | Generator[ChatbotAppStreamResponse, None, None]
+    ):
         """
         Handle response.
         :param application_generate_entity: application generate entity
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 7b4cb98bd4..4e57b4dedc 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -3,12 +3,6 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
-from graphon.enums import WorkflowType
-from graphon.graph_engine.command_channels import RedisChannel
-from graphon.graph_engine.layers import GraphEngineLayer
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
-from graphon.variables.variables import Variable
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -43,6 +37,12 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
+from graphon.enums import WorkflowType
+from graphon.graph_engine.command_channels import RedisChannel
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+from graphon.variables.variables import Variable
 from models import Workflow
 from models.model import App, Conversation, Message, MessageAnnotation
 from models.workflow import ConversationVariable
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index 5c9bc43992..7cb0c9a8d3 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -3,7 +3,7 @@ from typing import Any, cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
-    AppBlockingResponse,
+    AdvancedChatPausedBlockingResponse,
     AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
@@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
     PingStreamResponse,
+    StreamEvent,
 )
 
 
-class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class AdvancedChatAppGenerateResponseConverter(
+    AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
+):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_full_response(
+        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
+        if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
+            paused_data = blocking_response.data.model_dump(mode="json")
+            return {
+                "event": StreamEvent.WORKFLOW_PAUSED.value,
+                "task_id": blocking_response.task_id,
+                "id": blocking_response.data.id,
+                "message_id": blocking_response.data.message_id,
+                "conversation_id": blocking_response.data.conversation_id,
+                "mode": blocking_response.data.mode,
+                "answer": blocking_response.data.answer,
+                "metadata": blocking_response.data.metadata,
+                "created_at": blocking_response.data.created_at,
+                "workflow_run_id": blocking_response.data.workflow_run_id,
+                "data": paused_data,
+            }
+
         response = {
-            "event": "message",
+            "event": StreamEvent.MESSAGE.value,
             "task_id": blocking_response.task_id,
             "id": blocking_response.data.id,
             "message_id": blocking_response.data.message_id,
@@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_simple_response(
+        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -50,14 +70,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)
 
         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, Any, None]:
+    ) -> Generator[dict[str, Any] | str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -88,7 +109,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, Any, None]:
+    ) -> Generator[dict[str, Any] | str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 0ce9ddce9e..82dbf5381d 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -9,12 +9,6 @@ from datetime import datetime
 from threading import Thread
 from typing import Any, Union
 
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.nodes import BuiltinNodeTypes
-from graphon.runtime import GraphRuntimeState
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -59,14 +53,18 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
+    HumanInputRequiredPauseReasonPayload,
+    HumanInputRequiredResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     MessageEndStreamResponse,
     PingStreamResponse,
     StreamResponse,
+    WorkflowPauseStreamResponse,
     WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -77,6 +75,12 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
 from core.workflow.file_reference import resolve_file_record_id
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.nodes import BuiltinNodeTypes
+from graphon.runtime import GraphRuntimeState
 from libs.datetime_utils import naive_utc_now
 from models import Account, Conversation, EndUser, Message, MessageFile
 from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
@@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         if message.status == MessageStatus.PAUSED and message.answer:
             self._task_state.answer = message.answer
 
-    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def process(
+        self,
+    ) -> Union[
+        ChatbotAppBlockingResponse,
+        AdvancedChatPausedBlockingResponse,
+        Generator[ChatbotAppStreamResponse, None, None],
+    ]:
         """
         Process generate task pipeline.
         :return:
@@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         else:
             return self._to_blocking_response(generator)
 
-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
+    def _to_blocking_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
         """
         Process blocking response.
         :return:
         """
+        human_input_responses: list[HumanInputRequiredResponse] = []
         for stream_response in generator:
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
+            elif isinstance(stream_response, HumanInputRequiredResponse):
+                human_input_responses.append(stream_response)
+            elif isinstance(stream_response, WorkflowPauseStreamResponse):
+                return AdvancedChatPausedBlockingResponse(
+                    task_id=stream_response.task_id,
+                    data=AdvancedChatPausedBlockingResponse.Data(
+                        id=self._message_id,
+                        mode=self._conversation_mode,
+                        conversation_id=self._conversation_id,
+                        message_id=self._message_id,
+                        workflow_run_id=stream_response.data.workflow_run_id,
+                        answer=self._task_state.answer,
+                        metadata=self._message_end_to_stream_response().metadata,
+                        created_at=self._message_created_at,
+                        paused_nodes=stream_response.data.paused_nodes,
+                        reasons=stream_response.data.reasons,
+                        status=stream_response.data.status,
+                        elapsed_time=stream_response.data.elapsed_time,
+                        total_tokens=stream_response.data.total_tokens,
+                        total_steps=stream_response.data.total_steps,
+                    ),
+                )
             elif isinstance(stream_response, MessageEndStreamResponse):
                 extras = {}
                 if stream_response.metadata:
@@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
             else:
                 continue
 
+        if human_input_responses:
+            return self._build_paused_blocking_response_from_human_input(human_input_responses)
+
         raise ValueError("queue listening stopped unexpectedly.")
 
+    def _build_paused_blocking_response_from_human_input(
+        self, human_input_responses: list[HumanInputRequiredResponse]
+    ) -> AdvancedChatPausedBlockingResponse:
+        runtime_state = self._resolve_graph_runtime_state()
+        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
+        reasons = [
+            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
+            for response in human_input_responses
+        ]
+
+        return AdvancedChatPausedBlockingResponse(
+            task_id=self._application_generate_entity.task_id,
+            data=AdvancedChatPausedBlockingResponse.Data(
+                id=self._message_id,
+                mode=self._conversation_mode,
+                conversation_id=self._conversation_id,
+                message_id=self._message_id,
+                workflow_run_id=human_input_responses[-1].workflow_run_id,
+                answer=self._task_state.answer,
+                metadata=self._message_end_to_stream_response().metadata,
+                created_at=self._message_created_at,
+                paused_nodes=paused_nodes,
+                reasons=reasons,
+                status=WorkflowExecutionStatus.PAUSED,
+                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
+                total_tokens=runtime_state.total_tokens,
+                total_steps=runtime_state.node_run_steps,
+            ),
+        )
+
     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
     ) -> Generator[ChatbotAppStreamResponse, Any, None]:
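`_build_paused_blocking_response_from_human_input` above dedupes paused node ids with `list(dict.fromkeys(...))` rather than `set()`: dicts preserve insertion order, so the order in which nodes paused survives deduplication. In isolation (node ids here are made up):

```python
# Stand-in node ids; the pipeline feeds response.data.node_id values here.
node_ids = ["form_a", "form_b", "form_a", "form_c"]

paused_nodes = list(dict.fromkeys(node_ids))
assert paused_nodes == ["form_a", "form_b", "form_c"]  # order kept, dupes dropped
```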
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 5872f6b264..5cdc477028 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, overload
 
 from flask import Flask, current_app
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from configs import dify_config
@@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, EndUser
 from services.conversation_service import ConversationService
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index a20d3f3c38..cae0eee0df 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -1,9 +1,6 @@
 import logging
 from typing import cast
 
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import select
 
 from core.agent.cot_chat_agent_runner import CotChatAgentRunner
@@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from extensions.ext_database import db
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py
index 0c146c388f..03bc0a9108 100644
--- a/api/core/app/apps/agent_chat/generate_response_converter.py
+++ b/api/core/app/apps/agent_chat/generate_response_converter.py
@@ -1,5 +1,7 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
+
+from pydantic import JsonValue
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
 )
 
 
-class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
        """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py
index 6e5a86505c..abcbb2f943 100644
--- a/api/core/app/apps/base_app_generate_response_converter.py
+++ b/api/core/app/apps/base_app_generate_response_converter.py
@@ -1,19 +1,22 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import Any, Union, cast
 
-from graphon.model_runtime.errors.invoke import InvokeError
+from pydantic import JsonValue
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 
 logger = logging.getLogger(__name__)
 
 
-class AppGenerateResponseConverter(ABC):
-    _blocking_response_type: type[AppBlockingResponse]
+class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
+    @classmethod
+    def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
+        return cast(TBlockingResponse, response)
 
     @classmethod
     def convert(
@@ -21,45 +24,45 @@ class AppGenerateResponseConverter(ABC):
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
-                return cls.convert_blocking_full_response(response)
+                return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
             else:
 
-                def _generate_full_response() -> Generator[dict | str, Any, None]:
+                def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
                     yield from cls.convert_stream_full_response(response)
 
                 return _generate_full_response()
         else:
             if isinstance(response, AppBlockingResponse):
-                return cls.convert_blocking_simple_response(response)
+                return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
             else:
 
-                def _generate_simple_response() -> Generator[dict | str, Any, None]:
+                def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
                     yield from cls.convert_stream_simple_response(response)
 
                 return _generate_simple_response()
 
     @classmethod
     @abstractmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         raise NotImplementedError
 
     @classmethod
@@ -107,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
         return metadata
 
     @classmethod
-    def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
+    def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
         """
         Error to stream response.
         :param e: exception
         :return:
         """
-        error_responses: dict[type[Exception], dict[str, Any]] = {
+        error_responses: dict[type[Exception], dict[str, JsonValue]] = {
             ValueError: {"code": "invalid_param", "status": 400},
             ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
             QuotaExceededError: {
@@ -127,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
         }
 
         # Determine the response based on the type of exception
-        data: dict[str, Any] | None = None
+        data: dict[str, JsonValue] | None = None
         for k, v in error_responses.items():
             if isinstance(e, k):
                 data = v
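The base-converter hunks above swap the `_blocking_response_type` class attribute for a PEP 695 class-scoped type parameter, so each subclass pins its blocking-response type statically and the `# type: ignore[override]` comments on the subclasses can be dropped. A toy reduction of the pattern (requires Python 3.12+; the class names here are invented):

```python
from typing import cast


class BlockingResponse: ...


class ChatBlockingResponse(BlockingResponse): ...


class ResponseConverter[TBlocking: BlockingResponse]:
    @classmethod
    def _cast_blocking(cls, response: BlockingResponse) -> TBlocking:
        # Runtime no-op; narrows the type for checkers only, mirroring
        # _cast_blocking_response in the hunk above.
        return cast(TBlocking, response)


class ChatConverter(ResponseConverter[ChatBlockingResponse]):
    pass


print(type(ChatConverter._cast_blocking(ChatBlockingResponse())).__name__)
```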
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 7eccd59d17..8e8ccf2b90 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -2,9 +2,6 @@ from collections.abc import Generator, Mapping, Sequence
 from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Union, final
 
-from graphon.enums import NodeType
-from graphon.file import File, FileUploadConfig
-from graphon.variables.input_entities import VariableEntityType
 from sqlalchemy.orm import Session
 
 from core.app.apps.draft_variable_saver import (
@@ -16,6 +13,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.enums import NodeType
+from graphon.file import File, FileUploadConfig
+from graphon.variables.input_entities import VariableEntityType
 from libs.orjson import orjson_dumps
 from models import Account, EndUser
 from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 20bf81aeec..d1771452c5 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -7,7 +7,6 @@ from enum import IntEnum, auto
 from typing import Any
 
 from cachetools import TTLCache, cachedmethod
-from graphon.runtime import GraphRuntimeState
 from redis.exceptions import RedisError
 from sqlalchemy.orm import DeclarativeMeta
 
@@ -22,6 +21,7 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
+from graphon.runtime import GraphRuntimeState
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 4aebc0cb30..1251b397e2 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -5,17 +5,6 @@ from collections.abc import Generator, Mapping, Sequence
 from mimetypes import guess_extension
 from typing import TYPE_CHECKING, Any, Union
 
-from graphon.file import FileTransferMethod, FileType
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    TextPromptMessageContent,
-)
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
@@ -41,6 +30,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from core.tools.tool_file_manager import ToolFileManager
 from extensions.ext_database import db
+from graphon.file import FileTransferMethod, FileType
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
+from graphon.model_runtime.errors.invoke import InvokeBadRequestError
 from models.enums import CreatorUserRole, MessageFileBelongsTo
 from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index 891dcece73..58afefe296 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, overload
 
 from flask import Flask, copy_current_request_context, current_app
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 
 from configs import dify_config
@@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account
 from models.model import App, EndUser
 from services.conversation_service import ConversationService
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 050f763e95..077c5239f3 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -1,8 +1,6 @@
 import logging
 from typing import cast
 
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -18,6 +16,8 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py
index f23ee7f89f..26efcbfafd 100644
--- a/api/core/app/apps/chat/generate_response_converter.py
+++ b/api/core/app/apps/chat/generate_response_converter.py
@@ -1,5 +1,7 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
+
+from pydantic import JsonValue
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
 )
 
 
-class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py
index ab277857fe..2a90fbdad0 100644
--- a/api/core/app/apps/common/graph_runtime_state_support.py
+++ b/api/core/app/apps/common/graph_runtime_state_support.py
@@ -4,9 +4,8 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from graphon.runtime import GraphRuntimeState
-
 from core.workflow.system_variables import SystemVariableKey, get_system_text
+from graphon.runtime import GraphRuntimeState
 
 if TYPE_CHECKING:
     from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
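Several converter hunks above retype loosely annotated dicts as `dict[str, JsonValue]`. `JsonValue` is pydantic v2's recursive alias for JSON-serializable values, a tighter contract than `Any` for chunks that are ultimately serialized to the wire. A small self-contained illustration (keys invented):

```python
from pydantic import JsonValue  # available in pydantic v2 (2.6+)

response_chunk: dict[str, JsonValue] = {
    "event": "message",
    "answer": "hello",
    "metadata": {"usage": {"total_tokens": 42}},  # nested JSON values are fine
}
# A non-JSON value (an arbitrary object, say) would be flagged by a type
# checker, though not at runtime, since this is only an annotation.
print(response_chunk["event"])
```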
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index a515531616..7bab3f7bff 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -6,19 +6,6 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, NewType, TypedDict, Union
 
-from graphon.entities import WorkflowStartReason
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import (
-    BuiltinNodeTypes,
-    WorkflowExecutionStatus,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.file import FILE_MODEL_IDENTITY, File
-from graphon.runtime import GraphRuntimeState
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.variables import Variable
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -65,9 +52,23 @@ from core.tools.tool_manager import ToolManager
 from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
 from core.trigger.trigger_manager import TriggerManager
 from core.workflow.human_input_forms import load_form_tokens_by_form_id
+from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
 from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import (
+    BuiltinNodeTypes,
+    WorkflowExecutionStatus,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.file import FILE_MODEL_IDENTITY, File
+from graphon.runtime import GraphRuntimeState
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.variables import Variable
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
 from models.human_input import HumanInputForm
@@ -336,7 +337,26 @@ class WorkflowResponseConverter:
             except (TypeError, json.JSONDecodeError):
                 definition_payload = {}
             display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
-        form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
+        form_token_by_form_id = load_form_tokens_by_form_id(
+            human_input_form_ids,
+            session=session,
+            surface=(
+                HumanInputSurface.SERVICE_API
+                if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
+                else None
+            ),
+        )
+
+        # Reconnect paths must preserve the same pause-reason contract as live streams;
+        # otherwise clients see schema drift after resume.
+        pause_reasons = enrich_human_input_pause_reasons(
+            pause_reasons,
+            form_tokens_by_form_id=form_token_by_form_id,
+            expiration_times_by_form_id={
+                form_id: int(expiration_time.timestamp())
+                for form_id, expiration_time in expiration_times_by_form_id.items()
+            },
+        )
 
         responses: list[StreamResponse] = []
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 61339b316a..423bfdac51 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping
 from typing import Any, Literal, overload
 
 from flask import Flask, copy_current_request_context, current_app
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from pydantic import ValidationError
 from sqlalchemy import select
 
@@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from models import Account, App, EndUser, Message
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index b216f7cf7b..6bb1ecdcb1 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -1,8 +1,6 @@
 import logging
 from typing import cast
 
-from graphon.file import File
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -16,6 +14,8 @@ from core.model_manager import ModelInstance
 from core.moderation.base import ModerationError
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
+from graphon.file import File
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
 from models.model import App, Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py
index a4f574642d..ad978f58e0 100644
--- a/api/core/app/apps/completion/generate_response_converter.py
+++ b/api/core/app/apps/completion/generate_response_converter.py
@@ -1,5 +1,7 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
+
+from pydantic import JsonValue
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
@@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
 )
 
 
-class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = CompletionAppBlockingResponse
-
+class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        response = {
+        response: dict[str, Any] = {
             "event": "message",
             "task_id": blocking_response.task_id,
             "id": blocking_response.data.id,
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "message_id": chunk.message_id,
                 "created_at": chunk.created_at,
@@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[dict | str, None, None]:
+    ) -> Generator[dict[str, Any] | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue
 
-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "message_id": chunk.message_id,
                 "created_at": chunk.created_at,
diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py
index 68631bb230..c04f20c796 100644
--- a/api/core/app/apps/message_generator.py
+++ b/api/core/app/apps/message_generator.py
@@ -1,6 +1,7 @@
-from collections.abc import Callable, Generator, Mapping
+from collections.abc import Callable, Generator, Iterable, Mapping
 
 from core.app.apps.streaming_utils import stream_topic_events
+from core.app.entities.task_entities import StreamEvent
 from extensions.ext_redis import get_pubsub_broadcast_channel
 from libs.broadcast_channel.channel import Topic
 from models.model import AppMode
@@ -26,6 +27,7 @@ class MessageGenerator:
         idle_timeout=300,
         ping_interval: float = 10.0,
         on_subscribe: Callable[[], None] | None = None,
+        terminal_events: Iterable[str | StreamEvent] | None = None,
     ) -> Generator[Mapping | str, None, None]:
         topic = cls.get_response_topic(app_mode, workflow_run_id)
         return stream_topic_events(
@@ -33,4 +35,5 @@ class MessageGenerator:
             idle_timeout=idle_timeout,
             ping_interval=ping_interval,
             on_subscribe=on_subscribe,
+            terminal_events=terminal_events,
         )
:param stream_response: stream response @@ -66,7 +64,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 83c74b86e5..4a76d0809e 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,8 +10,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, cast, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -29,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDriveBrowseFilesRequest, @@ -43,6 +45,8 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline @@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. 
:param application_generate_entity: application generate entity diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 36daaf09e9..2ee0ae27eb 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,12 +2,6 @@ import logging import time from typing import cast -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -26,6 +20,12 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py index af3441aca3..5743bad4b6 100644 --- a/api/core/app/apps/streaming_utils.py +++ b/api/core/app/apps/streaming_utils.py @@ -59,7 +59,7 @@ def stream_topic_events( def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: - if not terminal_events: + if terminal_events is None: return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} values: set[str] = set() for item in terminal_events: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6074e81d1e..e811c2b2e0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,10 +8,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, overload from flask import Flask, current_app -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -29,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.app.layers.pause_state_persist_layer import 
PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args @@ -38,6 +38,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom @@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2cb8088971..cfb9208486 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,12 +3,6 @@ import time from collections.abc import Sequence from typing import cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -21,6 +15,11 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c64f44a603..4037388798 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -9,24 +9,29 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, ) -class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): - 
_blocking_response_type = WorkflowAppBlockingResponse - +class WorkflowAppGenerateResponseConverter( + AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse] +): @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.model_dump() + return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -37,7 +42,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,7 +71,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index f1b8b08eaa..87d9b73078 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,9 +4,6 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -45,12 +42,15 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredPauseReasonPayload, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowPauseStreamResponse, @@ -61,6 +61,9 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser @@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) self._graph_runtime_state: 
GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state - def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + def process( + self, + ) -> Union[ + WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None] + ]: """ Process generate task pipeline. :return: @@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]: """ To blocking response. :return: """ + human_input_responses: list[HumanInputRequiredResponse] = [] for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, HumanInputRequiredResponse): + human_input_responses.append(stream_response) elif isinstance(stream_response, WorkflowPauseStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppPausedBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.workflow_run_id, - data=WorkflowAppBlockingResponse.Data( + data=WorkflowAppPausedBlockingResponse.Data( id=stream_response.data.workflow_run_id, workflow_id=self._workflow.id, status=stream_response.data.status, @@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_steps=stream_response.data.total_steps, created_at=stream_response.data.created_at, finished_at=None, + paused_nodes=stream_response.data.paused_nodes, + reasons=stream_response.data.reasons, ), ) - return response elif isinstance(stream_response, WorkflowFinishStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( @@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ), ) - return response else: continue + if human_input_responses: + return self._build_paused_blocking_response_from_human_input(human_input_responses) + raise ValueError("queue listening stopped unexpectedly.") + def _build_paused_blocking_response_from_human_input( + self, human_input_responses: list[HumanInputRequiredResponse] + ) -> WorkflowAppPausedBlockingResponse: + runtime_state = self._resolve_graph_runtime_state() + paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses)) + created_at = int(runtime_state.start_at) + reasons = [ + HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json") + for response in human_input_responses + ] + + return WorkflowAppPausedBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=human_input_responses[-1].workflow_run_id, + data=WorkflowAppPausedBlockingResponse.Data( + id=human_input_responses[-1].workflow_run_id, + workflow_id=self._workflow.id, + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + error=None, + elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at, + total_tokens=runtime_state.total_tokens, + total_steps=runtime_state.node_run_steps, + created_at=created_at, + finished_at=None, + 
paused_nodes=paused_nodes, + reasons=reasons, + ), + ) + def _to_stream_response( self, generator: Generator[StreamResponse, None, None] ) -> Generator[WorkflowAppStreamResponse, None, None]: @@ -682,15 +727,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None): invoke_from = self._application_generate_entity.invoke_from - if invoke_from == InvokeFrom.SERVICE_API: - created_from = WorkflowAppLogCreatedFrom.SERVICE_API - elif invoke_from == InvokeFrom.EXPLORE: - created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP - elif invoke_from == InvokeFrom.WEB_APP: - created_from = WorkflowAppLogCreatedFrom.WEB_APP - else: - # not save log for debugging - return + match invoke_from: + case InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + case InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + case InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION: + # not save log for debugging + return if not workflow_run_id: return diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 437432611d..047b54c86c 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,39 +3,6 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph import Graph -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,6 +49,39 @@ from core.workflow.system_variables import ( from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + 
NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index a3fb7b4c5d..09992f4bbf 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 482f995d8e..221b7fb058 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 88faf235d1..ad05566521 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,15 +1,16 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any +from typing import Any, Literal -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.nodes.human_input.entities import FormInput, UserAction -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, JsonValue from core.app.entities.agent_strategy import 
AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): @@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse): data: Data +class HumanInputRequiredPauseReasonPayload(BaseModel): + """ + Public pause-reason payload used by blocking responses when only + ``human_input_required`` events are available. + """ + + TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + form_id: str + node_id: str + node_title: str + form_content: str + inputs: Sequence[FormInput] = Field(default_factory=list) + actions: Sequence[UserAction] = Field(default_factory=list) + display_in_ui: bool = False + form_token: str | None = None + resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + expiration_time: int + + @classmethod + def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload": + return cls( + form_id=data.form_id, + node_id=data.node_id, + node_title=data.node_title, + form_content=data.form_content, + inputs=data.inputs, + actions=data.actions, + display_in_ui=data.display_in_ui, + form_token=data.form_token, + resolved_default_values=data.resolved_default_values, + expiration_time=data.expiration_time, + ) + + class HumanInputFormFilledResponse(StreamResponse): class Data(BaseModel): """ @@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): data: Data +class AdvancedChatPausedBlockingResponse(AppBlockingResponse): + """ + ChatbotAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + mode: str + conversation_id: str + message_id: str + workflow_run_id: str + answer: str + metadata: Mapping[str, object] = Field(default_factory=dict) + created_at: int + paused_nodes: Sequence[str] = Field(default_factory=list) + reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]]) + status: WorkflowExecutionStatus + elapsed_time: float + total_tokens: int + total_steps: int + + data: Data + + class CompletionAppBlockingResponse(AppBlockingResponse): """ CompletionAppBlockingResponse entity @@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): data: Data +class WorkflowAppPausedBlockingResponse(AppBlockingResponse): + """ + WorkflowAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + workflow_id: str + status: WorkflowExecutionStatus + outputs: Mapping[str, Any] | None = None + error: str | None = None + elapsed_time: float + total_tokens: int + total_steps: int + created_at: int + finished_at: int | None 
+        paused_nodes: Sequence[str] = Field(default_factory=list)
+        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
+
+    workflow_run_id: str
+    data: Data
+
+
 class AgentLogStreamResponse(StreamResponse):
     """
     AgentLogStreamResponse entity
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index d2d2fea4fb..d59f5125e3 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,9 +1,8 @@
 import logging
 
-from graphon.model_runtime.entities.message_entities import PromptMessage
-
 from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.helper import moderation
+from graphon.model_runtime.entities.message_entities import PromptMessage
 
 logger = logging.getLogger(__name__)
 
diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py
index 80d504ef1c..a583301f9b 100644
--- a/api/core/app/file_access/scope.py
+++ b/api/core/app/file_access/scope.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Iterator
+from collections.abc import Generator
 from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass
@@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
 
 
 @contextmanager
-def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
+def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
     token = _current_file_access_scope.set(scope)
     try:
         yield
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
index e09869f5f8..d5e6b04a4a 100644
--- a/api/core/app/layers/conversation_variable_persist_layer.py
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -9,11 +9,10 @@ scope updates that matter to chat applications.
import logging -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent - from core.workflow.system_variables import SystemVariableKey, get_system_text from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index c027f42788..9811f9f830 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index 8c8daf8712..bb9fc1b6fa 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore + from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent - from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 77c7bec67e..b60fe82ffe 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 278d0cb30b..5631caa1a5 100644 --- 
a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -1,22 +1,35 @@ from __future__ import annotations +from copy import deepcopy from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider - from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: + """Resolves and returns LLM credentials for a given provider and model. + + Fetched credentials are stored in :attr:`credentials_cache` and reused for + subsequent ``fetch`` calls for the same ``(provider_name, model_name)``. + Because of that cache, a single instance can return stale credentials after + the tenant or provider configuration changes (e.g. API key rotation). + + Do **not** keep one instance for the lifetime of a process or across + unrelated invocations. Create a new provider per request, workflow run, or + other bounded scope where up-to-date credentials matter. + """ + tenant_id: str provider_manager: ProviderManager + credentials_cache: dict[tuple[str, str], dict[str, Any]] def __init__( self, @@ -31,8 +44,12 @@ class DifyCredentialsProvider: user_id=run_context.user_id, ) self.provider_manager = provider_manager + self.credentials_cache = {} def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + if (provider_name, model_name) in self.credentials_cache: + return deepcopy(self.credentials_cache[(provider_name, model_name)]) + provider_configurations = self.provider_manager.get_configurations(self.tenant_id) provider_configuration = provider_configurations.get(provider_name) if not provider_configuration: @@ -47,6 +64,7 @@ class DifyCredentialsProvider: if credentials is None: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) return credentials @@ -66,7 +84,8 @@ class DifyModelFactory: provider_manager=create_plugin_provider_manager( tenant_id=run_context.tenant_id, user_id=run_context.user_id, - ) + ), + enable_credentials_cache=True, ) self.model_manager = model_manager @@ -85,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro tenant_id=run_context.tenant_id, user_id=run_context.user_id, ) - model_manager = ModelManager(provider_manager=provider_manager) + model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True) return ( DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 0bb10190c4..b6039e1e4e 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,4 +1,3 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import 
sessionmaker @@ -8,6 +7,7 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 10b9c36d3e..9e688589db 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,7 +1,6 @@ import logging import time -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,6 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6bb177fe02..e2e07ebaff 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,13 +4,6 @@ from collections.abc import Generator from threading import Thread from typing import Any, cast -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -60,6 +53,13 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index 77310baf74..1dd713821f 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,9 +1,8 @@ from typing import TypedDict +from core.tools.signature import sign_tool_file from graphon.file import FileTransferMethod from graphon.file import helpers as file_helpers - -from core.tools.signature import sign_tool_file from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 
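
The `DifyCredentialsProvider` docstring added in `api/core/app/llm/model_access.py` above pins down two easy-to-miss properties of the new credentials cache: entries are copied on write and on read, and an instance must stay scoped to a single request or workflow run. A minimal, self-contained sketch of that contract follows; `CachingCredentialsProvider` and the `resolve` callable are hypothetical stand-ins for illustration, not the real Dify classes:

```python
# Sketch of the caching contract described in the DifyCredentialsProvider
# docstring above. Assumption: the `resolve` callable stands in for the
# ProviderManager credential lookup; all names here are illustrative only.
from collections.abc import Callable
from copy import deepcopy
from typing import Any


class CachingCredentialsProvider:
    def __init__(self, resolve: Callable[[str, str], dict[str, Any]]) -> None:
        self._resolve = resolve
        self._cache: dict[tuple[str, str], dict[str, Any]] = {}

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        key = (provider_name, model_name)
        if key not in self._cache:
            # Copy on write: the cache owns its own snapshot of the credentials.
            self._cache[key] = deepcopy(self._resolve(provider_name, model_name))
        # Copy on read: callers may mutate the result without poisoning the cache.
        return deepcopy(self._cache[key])


# Scope one instance per request/run; a process-wide instance would keep
# serving pre-rotation credentials after an API key changes.
provider = CachingCredentialsProvider(lambda p, m: {"api_key": "sk-test"})
creds = provider.fetch("openai", "gpt-4o")
creds["api_key"] = "mutated"  # caller-side mutation...
assert provider.fetch("openai", "gpt-4o")["api_key"] == "sk-test"  # ...does not leak back
```

The double `deepcopy` mirrors the two `deepcopy` calls added to `fetch` in the diff above; dropping either one would let a shared dict alias escape the cache.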
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 8604235ef2..3a6f9d575a 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -9,17 +9,17 @@ import urllib.parse from collections.abc import Generator from typing import TYPE_CHECKING, Literal -from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from graphon.file.runtime import set_workflow_file_runtime - from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +from graphon.file import FileTransferMethod +from graphon.file.protocols import WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -44,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index c577ce0754..4a7918032e 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,17 +7,16 @@ This layer centralizes model-quota deduction outside node implementations. 
import logging from typing import TYPE_CHECKING, cast, final, override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 99e8015c0b..8b5a5b9d7f 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -12,10 +12,6 @@ from contextvars import Token from dataclasses import dataclass from typing import cast, final, override -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context @@ -28,6 +24,10 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index ada065a943..d521304615 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,6 +14,13 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -38,14 +45,6 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from 
core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now @@ -350,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 3d8a7a54f3..9e3c187210 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,9 +6,6 @@ import re import threading from collections.abc import Iterable -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -18,6 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index a5297fa33a..f0dcb13b62 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,6 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -31,6 +28,9 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService @@ -352,11 +352,11 @@ class DatasourceManager: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." 
+ upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 9c22d5e67c..352e6bfd49 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, TypedDict -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject, I18nObjectDict +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index c012e128f4..6a3f9e684a 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,11 +2,10 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from graphon.file import File, FileTransferMethod, FileType - from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index d304c982cd..04ae193396 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias -from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field +from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index a440829b46..bfa4f56915 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,7 +6,6 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -16,6 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 84d95c38c6..e99a131500 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from pydantic import BaseModel, ConfigDict + from graphon.model_runtime.entities.common_entities 
import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity -from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d07f6f913a..38b87e2cd1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -8,16 +8,6 @@ from collections.abc import Iterator, Sequence from json import JSONDecodeError from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -34,6 +24,16 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials( - self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None - ): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. 
:param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = 
self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel): model: str, credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. 
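A minimal sketch, not part of the patch, of the [__HIDDEN__] round-trip that validate_provider_credentials now performs inline: a resubmitted placeholder is swapped for the decrypted stored secret before provider-side validation, and secret fields are re-encrypted afterwards. HIDDEN_VALUE and the decrypt call mirror the encrypter usage above; merge_hidden_secrets is a hypothetical helper for illustration only.

HIDDEN_VALUE = "[__HIDDEN__]"  # placeholder the client sends back for unchanged secrets


def merge_hidden_secrets(
    submitted: dict[str, str],
    stored_encrypted: dict[str, str],
    secret_keys: set[str],
    decrypt_token,
) -> dict[str, str]:
    # Hypothetical helper mirroring the inline loop above: restore the real
    # secret value so credential validation never sees the placeholder.
    merged = dict(submitted)
    for key in secret_keys:
        if merged.get(key) == HIDDEN_VALUE and key in stored_encrypted:
            merged[key] = decrypt_token(stored_encrypted[key])
    return merged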
@@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None @@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise 
ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 95431c0e01..72b29c2277 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,7 +3,6 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Any, Union -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -13,6 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index f9e6099049..01139d07e2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast import httpx @@ -14,7 +14,7 @@ class APIBasedExtensionRequestor: self.api_endpoint = api_endpoint self.api_key = api_key - def request(self, point: APIBasedExtensionPoint, params: dict): + def 
request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]: """ Request the api. @@ -49,4 +49,4 @@ class APIBasedExtensionRequestor: if response.status_code != 200: raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}") - return cast(dict, response.json()) + return cast(dict[str, Any], response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index b79dbeb7e0..c08e319aac 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -21,8 +21,8 @@ class ExtensionModule(StrEnum): class ModuleExtension(BaseModel): extension_class: Any | None = None name: str - label: dict | None = None - form_schema: list | None = None + label: dict[str, Any] | None = None + form_schema: list[dict[str, Any]] | None = None builtin: bool = True position: int | None = None diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index f7a64cea1b..f404aa7286 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -13,7 +13,7 @@ class ExternalDataToolFactory: ) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 35bfcfb6a5..951e065b2c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,7 +4,6 @@ from threading import Lock from typing import Any import httpx -from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce380..38864a1830 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/creators.py b/api/core/helper/creators.py new file mode 100644 index 0000000000..b01e16f18a --- /dev/null +++ b/api/core/helper/creators.py @@ -0,0 +1,41 @@ +""" +Helper module for Creators Platform integration. + +Provides functionality to upload DSL files to the Creators Platform +and generate redirect URLs with OAuth authorization codes. 
+"""
+
+import logging
+from urllib.parse import urlencode
+
+import httpx
+from yarl import URL
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
+
+
+def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
+    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
+    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
+    response.raise_for_status()
+    data = response.json()
+    claim_code = data.get("data", {}).get("claim_code")
+    if not claim_code:
+        raise ValueError("Creators Platform did not return a valid claim_code")
+    return claim_code
+
+
+def get_redirect_url(user_account_id: str, claim_code: str) -> str:
+    base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
+    params: dict[str, str] = {"dsl_claim_code": claim_code}
+    client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
+    if client_id:
+        from services.oauth_server import OAuthServerService
+
+        oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
+        params["oauth_code"] = oauth_code
+    return f"{base_url}?{urlencode(params)}"
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index a1e782a094..f169f247cf 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -2,14 +2,13 @@ import logging
 import secrets
 from typing import cast

-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities import DEFAULT_PLUGIN_ID
 from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
 from extensions.ext_hosting_provider import hosting_configuration
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeBadRequestError
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
 from models.provider import ProviderType

 logger = logging.getLogger(__name__)
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index e38592bb7b..91e92712b7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
 from configs import dify_config
 from core.helper.http_client_pooling import get_pooled_http_client
 from core.tools.errors import ToolSSRFError
+from graphon.http.response import HttpResponse

 logger = logging.getLogger(__name__)

@@ -267,4 +268,47 @@ class SSRFProxy:
         return patch(url=url, max_retries=max_retries, **kwargs)


+def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
+    """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
+    return HttpResponse(
+        status_code=response.status_code,
+        headers=dict(response.headers),
+        content=response.content,
+        url=str(response.url) if response.url else None,
+        reason_phrase=response.reason_phrase,
+        fallback_text=response.text,
+    )
+
+
+class GraphonSSRFProxy:
+    """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
+
+    @property
+    def max_retries_exceeded_error(self) -> type[Exception]:
+        return max_retries_exceeded_error
+
+    @property
+    def request_error(self) -> type[Exception]:
+        return request_error
+
+    def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
+
+    def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
+
+    def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
+
+    def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
+
+    def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
+
+    def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+        return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
+
+
 ssrf_proxy = SSRFProxy()
+graphon_ssrf_proxy = GraphonSSRFProxy()
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index 60f5434bc1..8bcb899b23 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -1,10 +1,12 @@
+from typing import Any
+
 from flask import Flask
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel

 from configs import dify_config
 from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
+from graphon.model_runtime.entities.model_entities import ModelType


 class HostingQuota(BaseModel):
@@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota):

 class HostingProvider(BaseModel):
     enabled: bool = False
-    credentials: dict | None = None
+    credentials: dict[str, Any] | None = None
     quota_unit: QuotaUnit | None = None
     quotas: list[HostingQuota] = []

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 9aaf85dc0f..b6e33396d1 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -9,7 +9,6 @@ from collections.abc import Mapping
 from typing import Any

 from flask import Flask, current_app
-from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import delete, func, select, update
 from sqlalchemy.orm.exc import ObjectDeletedError

@@ -35,6 +34,7 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from models import Account
@@ -737,7 +737,7 @@ class IndexingRunner:
     def _update_document_index_status(
         document_id: str,
         after_indexing_status: IndexingStatus,
-        extra_update_params: dict[Any, Any] | None = None,
+        extra_update_params: Mapping[Any, Any] | None = None,
     ):
         """
         Update the document indexing status.
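A hedged usage sketch for the GraphonSSRFProxy adapter added above: Graphon components that accept an HttpClientProtocol can route outbound requests through the SSRF-filtered helpers while only ever handling the transport-agnostic HttpResponse. The fetch_status function is hypothetical and relies only on methods the adapter actually defines.

from core.helper.ssrf_proxy import graphon_ssrf_proxy


def fetch_status(url: str) -> int:
    # Traffic still flows through the pooled, SSRF-checked httpx helpers; the
    # adapter converts each httpx.Response into graphon.http.response.HttpResponse.
    response = graphon_ssrf_proxy.get(url, max_retries=3)
    return response.status_code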
@@ -764,7 +764,7 @@ class IndexingRunner: db.session.commit() @staticmethod - def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]): + def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]): """ Update the document segment by document id. """ diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index aa258c9f89..af2611bb0b 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,14 +2,9 @@ import json import logging import re from collections.abc import Sequence -from typing import Protocol, TypedDict, cast +from typing import Any, NotRequired, Protocol, TypedDict, cast import json_repair -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from core.app.app_config.entities import ModelConfig @@ -23,8 +18,6 @@ from core.llm_generator.prompts import ( LLM_MODIFY_CODE_SYSTEM, LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, - SUGGESTED_QUESTIONS_MAX_TOKENS, - SUGGESTED_QUESTIONS_TEMPERATURE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) @@ -35,12 +28,47 @@ from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow logger = logging.getLogger(__name__) +class SuggestedQuestionsModelConfig(TypedDict): + provider: str + name: str + completion_params: NotRequired[dict[str, object]] + + +def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]: + """ + Normalize raw completion params into invocation parameters and stop sequences. + + This mirrors the app-model access path by separating ``stop`` from provider + parameters before invocation, then drops non-positive token limits because + some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their + provider-specific output-token field. 
+    """
+    normalized_parameters = dict(completion_params)
+    stop_value = normalized_parameters.pop("stop", [])
+    if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value):
+        stop = stop_value
+    else:
+        stop = []
+
+    for token_limit_key in ("max_tokens", "max_output_tokens"):
+        token_limit = normalized_parameters.get(token_limit_key)
+        if isinstance(token_limit, int | float) and token_limit <= 0:
+            normalized_parameters.pop(token_limit_key, None)
+
+    return normalized_parameters, stop
+
+
 class WorkflowServiceInterface(Protocol):
     def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
         pass
@@ -123,8 +151,15 @@ class LLMGenerator:
         return name

     @classmethod
-    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
-        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
+    def generate_suggested_questions_after_answer(
+        cls,
+        tenant_id: str,
+        histories: str,
+        *,
+        instruction_prompt: str | None = None,
+        model_config: object | None = None,
+    ) -> Sequence[str]:
+        output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt)
         format_instructions = output_parser.get_format_instructions()

         prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")
@@ -133,10 +168,36 @@ class LLMGenerator:

         try:
             model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
-            model_instance = model_manager.get_default_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM,
-            )
+            configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {}
+            provider = configured_model.get("provider")
+            model_name = configured_model.get("name")
+            use_configured_model = False
+
+            if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
+                try:
+                    model_instance = model_manager.get_model_instance(
+                        tenant_id=tenant_id,
+                        model_type=ModelType.LLM,
+                        provider=provider,
+                        model=model_name,
+                    )
+                    use_configured_model = True
+                except Exception:
+                    logger.warning(
+                        "Failed to use configured suggested-questions model %s/%s, falling back to the default model",
+                        provider,
+                        model_name,
+                        exc_info=True,
+                    )
+                    model_instance = model_manager.get_default_model_instance(
+                        tenant_id=tenant_id,
+                        model_type=ModelType.LLM,
+                    )
+            else:
+                model_instance = model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.LLM,
+                )
         except InvokeAuthorizationError:
             return []

@@ -145,19 +206,29 @@ class LLMGenerator:
         questions: Sequence[str] = []
         try:
+            configured_completion_params = configured_model.get("completion_params")
+            if use_configured_model and isinstance(configured_completion_params, dict):
+                model_parameters, stop = _normalize_completion_params(configured_completion_params)
+            elif use_configured_model:
+                model_parameters = {}
+                stop = []
+            else:
+                # Default-model generation keeps the built-in suggested-questions tuning.
+ model_parameters = { + "max_tokens": 2560, + "temperature": 0.0, + } + stop = [] + response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), - model_parameters={ - "max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS, - "temperature": SUGGESTED_QUESTIONS_TEMPERATURE, - }, + model_parameters=model_parameters, + stop=stop, stream=False, ) text_content = response.message.get_text_content() questions = output_parser.parse(text_content) if text_content else [] - except InvokeError: - questions = [] except Exception: logger.exception("Failed to generate suggested questions after answer") questions = [] @@ -533,7 +604,7 @@ class LLMGenerator: def __instruction_modify_common( tenant_id: str, model_config: ModelConfig, - last_run: dict | None, + last_run: dict[str, Any] | None, current: str | None, error_message: str | None, instruction: str, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 9bdca1e83b..d2e375626f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,6 +5,11 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -21,11 +26,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance class ResponseFormat(StrEnum): @@ -202,7 +202,7 @@ def _handle_native_json_schema( structured_output_schema: Mapping, model_parameters: dict[str, Any], rules: list[ParameterRule], -): +) -> dict[str, Any]: """ Handle structured output for models with native JSON schema support. @@ -224,7 +224,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]): +def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: """ Set the appropriate response format parameter based on model rules. @@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema return {"schema": processed_schema, "name": "llm_response"} -def remove_additional_properties(schema: dict[str, Any]): +def remove_additional_properties(schema: dict[str, Any]) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. @@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]): remove_additional_properties(item) -def convert_boolean_to_string(schema: dict): +def convert_boolean_to_string(schema: dict[str, Any]) -> None: """ Convert boolean type specifications to string in JSON schema. 
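To make the normalization concrete, this is what _normalize_completion_params (added in llm_generator.py above) returns for a typical configured-model payload; the input values are illustrative only.

params, stop = _normalize_completion_params(
    {"temperature": 0.7, "stop": ["\n\n"], "max_tokens": 0}
)
assert params == {"temperature": 0.7}  # non-positive max_tokens is dropped
assert stop == ["\n\n"]                # "stop" is split out and passed to invoke_llm(..., stop=stop)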
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index eec771181f..7ac340926d 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -3,17 +3,28 @@ import logging import re from collections.abc import Sequence -from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT logger = logging.getLogger(__name__) class SuggestedQuestionsAfterAnswerOutputParser: + def __init__(self, instruction_prompt: str | None = None) -> None: + self._instruction_prompt = self._build_instruction_prompt(instruction_prompt) + + @staticmethod + def _build_instruction_prompt(instruction_prompt: str | None) -> str: + if not instruction_prompt or not instruction_prompt.strip(): + return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].' + def get_format_instructions(self) -> str: - return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + return self._instruction_prompt def parse(self, text: str) -> Sequence[str]: - action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) + stripped_text = text.strip() + action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL) questions: list[str] = [] if action_match is not None: try: @@ -23,4 +34,6 @@ class SuggestedQuestionsAfterAnswerOutputParser: else: if isinstance(json_obj, list): questions = [question for question in json_obj if isinstance(question, str)] + elif stripped_text: + logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200]) return questions diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ee9a016c95..3c6f8c468a 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -1,5 +1,4 @@ # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh -import os CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. @@ -96,8 +95,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( ) -# Default prompt for suggested questions (can be overridden by environment variable) -_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = ( +# Default prompt and model parameters for suggested questions. +DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keep each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response. 
" @@ -105,15 +104,6 @@ _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = ( '["question1","question2","question3"]\n' ) -# Environment variable override for suggested questions prompt -SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv( - "SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT -) - -# Configurable LLM parameters for suggested questions (can be overridden by environment variables) -SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256")) -SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0")) - GENERATOR_QA_PROMPT = ( " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" " in the long text. Please think step by step." diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 5c3cd0d8f8..acba3e666b 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -303,9 +303,16 @@ class StreamableHTTPTransport: if response.status_code == 404: if isinstance(message.root, JSONRPCRequest): + error_msg = ( + f"MCP server URL returned 404 Not Found: {self.url} " + "— verify the server URL is correct and the server is running" + if is_initialization + else "Session terminated by server" + ) self._send_session_terminated_error( ctx.server_to_client_queue, message.root.id, + message=error_msg, ) return @@ -381,12 +388,13 @@ class StreamableHTTPTransport: self, server_to_client_queue: ServerToClientQueue, request_id: RequestId, + message: str = "Session terminated by server", ): """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", id=request_id, - error=ErrorData(code=32600, message="Session terminated by server"), + error=ErrorData(code=32600, message=message), ) session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) server_to_client_queue.put(session_message) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 72171d1536..884610ca82 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,12 +3,11 @@ import logging from collections.abc import Mapping from typing import Any, NotRequired, TypedDict, cast -from graphon.variables.input_entities import VariableEntity, VariableEntityType - from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7e35044176..7b5a7635f1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 5809d6f74a..d840ee213c 100644 --- a/api/core/memory/token_buffer_memory.py +++ 
b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,14 @@ from collections.abc import Sequence +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -10,15 +19,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 36beb55d7f..457c888e33 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,20 +1,7 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from copy import deepcopy +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config from core.entities import PluginCredentialType @@ -25,9 +12,24 @@ from core.errors.error import ProviderTokenNotInitError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from 
graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -35,11 +37,13 @@ class ModelInstance: Model instance class. """ - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None: self.provider_model_bundle = provider_model_bundle self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + if credentials is None: + credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + self.credentials = credentials # Runtime LLM invocation fields. 
self.parameters: Mapping[str, Any] = {} self.stop: Sequence[str] = () @@ -169,7 +173,7 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -194,7 +198,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -214,7 +218,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, texts=texts, @@ -236,7 +240,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, @@ -253,7 +257,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, texts=texts, @@ -278,7 +282,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, query=query, @@ -306,7 +310,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, + self.model_type_instance.invoke_multimodal_rerank, model=self.model_name, credentials=self.credentials, query=query, @@ -325,7 +329,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, text=text, @@ -341,7 +345,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, file=file, @@ -358,14 +362,14 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, 
content_text=content_text, voice=voice, ) - def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke @@ -433,8 +437,30 @@ class ModelInstance: class ModelManager: - def __init__(self, provider_manager: ProviderManager): + """Resolves :class:`ModelInstance` objects for a tenant and provider. + + When ``enable_credentials_cache`` is ``True``, resolved credentials for each + ``(tenant_id, provider, model_type, model)`` are stored in + ``_credentials_cache`` and reused. That can return **stale** credentials after + API keys or provider settings change, so a manager constructed with + ``enable_credentials_cache=True`` should not be kept for the lifetime of a + process or shared across unrelated work. Prefer a new manager per request, + workflow run, or similar bounded scope. + + The default is ``enable_credentials_cache=False``; in that mode the internal + credential cache is not populated, and each ``get_model_instance`` call + loads credentials from the current provider configuration. + """ + + def __init__( + self, + provider_manager: ProviderManager, + *, + enable_credentials_cache: bool = False, + ) -> None: self._provider_manager = provider_manager + self._credentials_cache: dict[tuple[str, str, str, str], Any] = {} + self._enable_credentials_cache = enable_credentials_cache @classmethod def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": @@ -462,8 +488,19 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - model_instance = ModelInstance(provider_model_bundle, model) - return model_instance + cred_cache_key = (tenant_id, provider, model_type.value, model) + + if cred_cache_key in self._credentials_cache: + return ModelInstance( + provider_model_bundle, + model, + deepcopy(self._credentials_cache[cred_cache_key]), + ) + + ret = ModelInstance(provider_model_bundle, model) + if self._enable_credentials_cache: + self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials) + return ret def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 732803b332..6e6e94502c 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,9 +1,8 @@ from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType - from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult +from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9..d78ce90aa1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ 
-52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. - """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. 
- """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. 
- """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index cd63951537..e7ba6e502b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", 
"host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from 
dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index c715b9171c..c92438960a 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -1,20 +1,7 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator - -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -32,6 +19,19 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant @@ -226,7 +226,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): # invoke model response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) - def handle() -> Generator[dict, None, None]: + def handle() -> Generator[dict[str, Any], None, None]: for chunk in response: yield {"result": hexlify(chunk).decode("utf-8")} diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9478997494..9550e49992 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,3 +1,4 @@ +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( @@ -8,8 +9,6 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) - -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index fd2094228a..03398873e3 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,12 +1,12 @@ from typing import 
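A minimal usage sketch of the lazy-import lookup above. `provider_config_map` and the entry keys are real names from this patch; the wrapper function and its error message are illustrative only:

    from core.ops.ops_trace_manager import provider_config_map

    def load_trace_entry(provider: str):
        # KeyError: provider is not a known TracingProviderEnum case.
        # ImportError: provider is known, but its dify_trace_* package is not installed.
        try:
            return provider_config_map[provider]
        except ImportError as e:
            raise RuntimeError(f"install the tracing extra for {provider!r}") from e

    entry = load_trace_entry("langfuse")
    config_cls = entry["config_class"]   # e.g. LangfuseConfig
    trace_cls = entry["trace_instance"]  # e.g. LangFuseDataTrace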
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index c715b9171c..c92438960a 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -1,20 +1,7 @@
 import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator
-
-from graphon.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMResultChunkWithStructuredOutput,
-    LLMResultWithStructuredOutput,
-)
-from graphon.model_runtime.entities.message_entities import (
-    PromptMessage,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.model_entities import ModelType
+from typing import Any

 from core.app.llm import deduct_llm_quota
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -32,6 +19,19 @@ from core.plugin.entities.request import (
 )
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
+from graphon.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
+from graphon.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.account import Tenant
@@ -226,7 +226,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         # invoke model
         response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)

-        def handle() -> Generator[dict, None, None]:
+        def handle() -> Generator[dict[str, Any], None, None]:
             for chunk in response:
                 yield {"result": hexlify(chunk).decode("utf-8")}
diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py
index 9478997494..9550e49992 100644
--- a/api/core/plugin/backwards_invocation/node.py
+++ b/api/core/plugin/backwards_invocation/node.py
@@ -1,3 +1,4 @@
+from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from graphon.enums import BuiltinNodeTypes
 from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig
 from graphon.nodes.parameter_extractor.entities import (
@@ -8,8 +9,6 @@ from graphon.nodes.question_classifier.entities import (
     ClassConfig,
     QuestionClassifierNodeData,
 )
-
-from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from services.workflow_service import WorkflowService
diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py
index fd2094228a..03398873e3 100644
--- a/api/core/plugin/entities/marketplace.py
+++ b/api/core/plugin/entities/marketplace.py
@@ -1,12 +1,12 @@
 from typing import Any

-from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from pydantic import BaseModel, Field, computed_field, model_validator

 from core.plugin.entities.endpoint import EndpointProviderDeclaration
 from core.plugin.entities.plugin import PluginResourceRequirements
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderEntity
+from graphon.model_runtime.entities.provider_entities import ProviderEntity


 class MarketplacePluginDeclaration(BaseModel):
diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py
index 4d28032a57..89e0e8881c 100644
--- a/api/core/plugin/entities/plugin.py
+++ b/api/core/plugin/entities/plugin.py
@@ -3,7 +3,6 @@ from collections.abc import Mapping
 from enum import StrEnum, auto
 from typing import Any

-from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from packaging.version import InvalidVersion, Version
 from pydantic import BaseModel, Field, field_validator, model_validator
@@ -14,6 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderEntity
 from core.trigger.entities.entities import TriggerProviderEntity
+from graphon.model_runtime.entities.provider_entities import ProviderEntity


 class PluginInstallationSource(StrEnum):
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index e0ddb746c7..257638ad77 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -6,8 +6,6 @@ from datetime import datetime
 from enum import StrEnum
 from typing import Any

-from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from pydantic import BaseModel, ConfigDict, Field

 from core.agent.plugin_entities import AgentProviderEntityWithPlugin
@@ -18,6 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
 from core.trigger.entities.entities import TriggerProviderEntity
+from graphon.model_runtime.entities.model_entities import AIModelEntity
+from graphon.model_runtime.entities.provider_entities import ProviderEntity


 class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel):
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index 4a85952dcd..1474883204 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -4,6 +4,10 @@ from collections.abc import Mapping
 from typing import Any, Literal

 from flask import Response
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from core.entities.provider_entities import BasicProviderConfig
+from core.plugin.utils.http_parser import deserialize_response
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,10 +25,6 @@ from graphon.nodes.parameter_extractor.entities import (
 from graphon.nodes.question_classifier.entities import (
     ClassConfig,
 )
-from pydantic import BaseModel, ConfigDict, Field, field_validator
-
-from core.entities.provider_entities import BasicProviderConfig
-from core.plugin.utils.http_parser import deserialize_response


 class InvokeCredentials(BaseModel):
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 7f36560b49..9ee8469892 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -5,14 +5,6 @@ from collections.abc import Callable, Generator
 from typing import Any, cast

 import httpx
-from graphon.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
-)
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
 from pydantic import BaseModel
 from yarl import URL
@@ -37,6 +29,14 @@ from core.trigger.errors import (
     TriggerPluginInvokeError,
     TriggerProviderCredentialValidationError,
 )
+from graphon.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError

 plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
 _plugin_daemon_timeout_config = cast(
diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py
index 703af63f7c..47608bdfa6 100644
--- a/api/core/plugin/impl/model.py
+++ b/api/core/plugin/impl/model.py
@@ -2,13 +2,6 @@ import binascii
 from collections.abc import Generator, Sequence
 from typing import IO, Any

-from graphon.model_runtime.entities.llm_entities import LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
-from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDaemonInnerError,
@@ -20,6 +13,12 @@ from core.plugin.entities.plugin_daemon import (
     PluginVoicesResponse,
 )
 from core.plugin.impl.base import BasePluginClient
+from graphon.model_runtime.entities.llm_entities import LLMResultChunk
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import AIModelEntity
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
+from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
+from graphon.model_runtime.utils.encoders import jsonable_encoder


 class PluginModelClient(BasePluginClient):
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index 22c846b6de..4e66d58b5e 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -6,13 +6,6 @@ from collections.abc import Generator, Iterable, Sequence
 from threading import Lock
 from typing import IO, Any, Union

-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
-from graphon.model_runtime.entities.provider_entities import ProviderEntity
-from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
-from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
-from graphon.model_runtime.runtime import ModelRuntime
 from pydantic import ValidationError
 from redis import RedisError
@@ -21,6 +14,13 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.impl.asset import PluginAssetManager
 from core.plugin.impl.model import PluginModelClient
 from extensions.ext_redis import redis_client
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from graphon.model_runtime.entities.provider_entities import ProviderEntity
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
+from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
+from graphon.model_runtime.runtime import ModelRuntime
 from models.provider_ids import ModelProviderID

 logger = logging.getLogger(__name__)
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
             if not provider_schema.icon_small:
                 raise ValueError(f"Provider {provider} does not have small icon.")
             file_name = (
-                provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
+                provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
             )
         elif icon_type.lower() == "icon_small_dark":
             if not provider_schema.icon_small_dark:
                 raise ValueError(f"Provider {provider} does not have small dark icon.")
             file_name = (
-                provider_schema.icon_small_dark.zh_Hans
+                provider_schema.icon_small_dark.zh_hans
                 if lang.lower() == "zh_hans"
-                else provider_schema.icon_small_dark.en_US
+                else provider_schema.icon_small_dark.en_us
             )
         else:
             raise ValueError(f"Unsupported icon type: {icon_type}.")
diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py
index 4b29a6fc56..35abd2ae8c 100644
--- a/api/core/plugin/impl/model_runtime_factory.py
+++ b/api/core/plugin/impl/model_runtime_factory.py
@@ -2,9 +2,8 @@ from __future__ import annotations

 from typing import TYPE_CHECKING

-from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-
 from core.plugin.impl.model import PluginModelClient
+from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

 if TYPE_CHECKING:
     from core.model_manager import ModelManager
diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py
index 90350f8400..12d8e282b2 100644
--- a/api/core/plugin/utils/converter.py
+++ b/api/core/plugin/utils/converter.py
@@ -1,8 +1,7 @@
 from typing import Any

-from graphon.file import File
-
 from core.tools.entities.tool_entities import ToolSelector
+from graphon.file import File


 def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 19b5e9223a..24e05ef865 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -1,6 +1,13 @@
 from collections.abc import Mapping, Sequence
 from typing import cast

+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from graphon.file import File, file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -13,14 +20,6 @@ from graphon.model_runtime.entities import (
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from graphon.runtime import VariablePool

-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
-from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
-from core.prompt.prompt_transform import PromptTransform
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-

 class AdvancedPromptTransform(PromptTransform):
     """
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 9be70199b7..7c6280fe93 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -1,17 +1,16 @@
 from typing import cast

-from graphon.model_runtime.entities.message_entities import (
-    PromptMessage,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
 from core.app.entities.app_invoke_entities import (
     ModelConfigWithCredentialsEntity,
 )
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.prompt.prompt_transform import PromptTransform
+from graphon.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel


 class AgentHistoryPromptTransform(PromptTransform):
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 4539ae9f11..6ff2f44cdc 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,12 +1,11 @@
 from typing import Any

-from graphon.model_runtime.entities.message_entities import PromptMessage
-from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from graphon.model_runtime.entities.message_entities import PromptMessage
+from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey


 class PromptTransform:
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index d4e17613a2..1665bdeb52 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -4,6 +4,12 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
 from typing import TYPE_CHECKING, Any, TypedDict, cast

+from core.app.app_config.entities import PromptTemplateEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from graphon.file import file_manager
 from graphon.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
@@ -13,13 +19,6 @@ from graphon.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-
-from core.app.app_config.entities import PromptTemplateEntity
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.prompt.prompt_transform import PromptTransform
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode

 if TYPE_CHECKING:
@@ -313,7 +312,7 @@ class SimplePromptTransform(PromptTransform):

         return prompt_message

-    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str):
+    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]:
         """
         Get simple prompt rule.
         :param app_mode: app mode
@@ -325,7 +324,7 @@

         # Check if the prompt file is already loaded
         if prompt_file_name in prompt_file_contents:
-            return cast(dict, prompt_file_contents[prompt_file_name])
+            return cast(dict[str, Any], prompt_file_contents[prompt_file_name])

         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -338,7 +337,7 @@
         # Store the content of the prompt file
         prompt_file_contents[prompt_file_name] = content

-        return cast(dict, content)
+        return cast(dict[str, Any], content)

     def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
         # baichuan
diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py
index dbda749925..ba76eb0c4e 100644
--- a/api/core/prompt/utils/prompt_message_util.py
+++ b/api/core/prompt/utils/prompt_message_util.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import Any, cast

+from core.prompt.simple_prompt_transform import ModelMode
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
@@ -11,8 +12,6 @@ from graphon.model_runtime.entities import (
     TextPromptMessageContent,
 )

-from core.prompt.simple_prompt_transform import ModelMode
-

 class PromptMessageUtil:
     @staticmethod
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 39ef31632e..8969825be4 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -6,14 +6,6 @@ from collections.abc import Sequence
 from json import JSONDecodeError
 from typing import TYPE_CHECKING, Any

-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import (
-    ConfigurateMethod,
-    CredentialFormSchema,
-    FormType,
-    ProviderEntity,
-)
-from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from pydantic import TypeAdapter
 from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError
@@ -41,6 +33,14 @@ from core.helper.position_helper import is_filtered
 from extensions import ext_hosting_provider
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.entities.provider_entities import (
+    ConfigurateMethod,
+    CredentialFormSchema,
+    FormType,
+    ProviderEntity,
+)
+from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from models.provider import (
     LoadBalancingModelConfig,
     Provider,
@@ -70,12 +70,32 @@ class ProviderManager:
     Request-bound managers may carry caller identity in that runtime, and the
     resulting ``ProviderConfiguration`` objects must reuse it for downstream
     model-type and schema lookups.
+
+    Configuration assembly is cached per manager instance so call chains that
+    share one request-scoped manager can reuse the same provider graph instead
+    of rebuilding it for every lookup. Call ``clear_configurations_cache()``
+    when a long-lived manager needs to observe writes performed within the same
+    instance scope.
     """

+    decoding_rsa_key: Any | None
+    decoding_cipher_rsa: Any | None
+    _model_runtime: ModelRuntime
+    _configurations_cache: dict[str, ProviderConfigurations]
+
     def __init__(self, model_runtime: ModelRuntime):
         self.decoding_rsa_key = None
         self.decoding_cipher_rsa = None
         self._model_runtime = model_runtime
+        self._configurations_cache = {}
+
+    def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
+        """Drop assembled provider configurations cached on this manager instance."""
+        if tenant_id is None:
+            self._configurations_cache.clear()
+            return
+
+        self._configurations_cache.pop(tenant_id, None)

     def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
         """
@@ -114,6 +134,10 @@ class ProviderManager:
         :param tenant_id:
         :return:
         """
+        cached_configurations = self._configurations_cache.get(tenant_id)
+        if cached_configurations is not None:
+            return cached_configurations
+
         # Get all provider records of the workspace
         provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@@ -273,6 +297,8 @@

         provider_configurations[str(provider_id_entity)] = provider_configuration

+        self._configurations_cache[tenant_id] = provider_configurations
+
         # Return the encapsulated object
         return provider_configurations
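A minimal sketch of the per-instance cache contract described in the docstring above; `get_configurations` and `clear_configurations_cache` are the real method names from this patch, while the setup around them is illustrative:

    manager = ProviderManager(model_runtime)

    # Repeated lookups on one request-scoped manager reuse the assembled graph.
    configs_a = manager.get_configurations(tenant_id="tenant-1")
    configs_b = manager.get_configurations(tenant_id="tenant-1")
    assert configs_b is configs_a  # served from _configurations_cache

    # A long-lived manager must drop the cached assembly to observe its own writes.
    manager.clear_configurations_cache(tenant_id="tenant-1")
    assert manager.get_configurations(tenant_id="tenant-1") is not configs_a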
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py
index 9ce91f52ff..ca530748ed 100644
--- a/api/core/rag/data_post_processor/data_post_processor.py
+++ b/api/core/rag/data_post_processor/data_post_processor.py
@@ -1,8 +1,5 @@
 from typing import TypedDict

-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-
 from core.model_manager import ModelInstance, ModelManager
 from core.rag.data_post_processor.reorder import ReorderRunner
 from core.rag.index_processor.constant.query_type import QueryType
@@ -11,6 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
 from core.rag.rerank.rerank_base import BaseRerankRunner
 from core.rag.rerank.rerank_factory import RerankRunnerFactory
 from core.rag.rerank.rerank_type import RerankMode
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError


 class RerankingModelDict(TypedDict):
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index ed264878d3..392af351b6 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
             "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
-        keyword_data_source_type = dataset_keyword_table.data_source_type
+        keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
         if keyword_data_source_type == "database":
+            if dataset_keyword_table is None:
+                return
             dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
             db.session.commit()
         else:
@@ -154,7 +156,8 @@ class Jieba(BaseKeyword):
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return dict(keyword_table_dict["__data__"]["table"])
+                data: Any = keyword_table_dict["__data__"]
+                return dict(data["table"])
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 84f35c25f8..2af8238cc4 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Callable
 from operator import itemgetter
 from typing import cast
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:

     def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
         # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
-        top_k = kwargs.pop("topK", top_k)
+        top_k = cast(int | None, kwargs.pop("topK", top_k))
+        if top_k is None:
+            top_k = 20
         cut = getattr(jieba, "cut", None)
         if self._lcut:
             tokens = self._lcut(sentence)
         elif callable(cut):
-            tokens = list(cut(sentence))
+            tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
         else:
             tokens = re.findall(r"\w+", sentence)
@@ -106,9 +109,9 @@ class JiebaKeywordTableHandler:
         """Extract keywords with JIEBA tfidf."""
         keywords = self._tfidf.extract_tags(
             sentence=text,
-            topK=max_keywords_per_chunk,
+            topK=max_keywords_per_chunk or 10,
         )
-        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
         keywords = cast(list[str], keywords)

         return set(self._expand_tokens_with_subtokens(set(keywords)))
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index f978e072f3..c60d19045a 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any, NotRequired, TypedDict

 from flask import Flask, current_app
-from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only
@@ -24,6 +23,7 @@ from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.signature import sign_upload_file
 from extensions.ext_database import db
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import (
     ChildChunk,
     Dataset,
@@ -158,7 +158,7 @@ class RetrievalService:
             )

         if futures:
-            for future in concurrent.futures.as_completed(futures, timeout=3600):
+            for _ in concurrent.futures.as_completed(futures, timeout=3600):
                 if exceptions:
                     for f in futures:
                         f.cancel()
@@ -195,6 +195,23 @@
             )
         return all_documents

+    @classmethod
+    def _filter_documents_by_vector_score_threshold(
+        cls, documents: list[Document], score_threshold: float | None
+    ) -> list[Document]:
+        """Keep documents whose stored retrieval score meets the threshold.
+
+        Used when hybrid search skips early vector thresholding but no rerank
+        runner applies a threshold afterward (same rule as ``calculate_vector_score``).
+        """
+        if score_threshold is None:
+            return documents
+        return [
+            document
+            for document in documents
+            if document.metadata and document.metadata.get("score", 0) >= score_threshold
+        ]
+
     @classmethod
     def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
         """Deduplicate documents in O(n) while preserving first-seen order.
@@ -294,13 +311,20 @@
         vector = Vector(dataset=dataset)
         documents = []

+        # Hybrid search merges keyword / full-text / vector hits and then reranks
+        # (weighted fusion or reranking model). Applying the user score threshold at
+        # vector retrieval time uses embedding similarity, which is not comparable to
+        # reranked or fused scores and incorrectly drops high-quality chunks (#35233).
+        embedding_score_threshold = (
+            0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold
+        )
         if query_type == QueryType.TEXT_QUERY:
             documents.extend(
                 vector.search_by_vector(
                     query,
                     search_type="similarity_score_threshold",
                     top_k=top_k,
-                    score_threshold=score_threshold,
+                    score_threshold=embedding_score_threshold,
                     filter={"group_id": [dataset.id]},
                     document_ids_filter=document_ids_filter,
                 )
@@ -312,7 +336,7 @@
                 vector.search_by_file(
                     file_id=query,
                     top_k=top_k,
-                    score_threshold=score_threshold,
+                    score_threshold=embedding_score_threshold,
                     filter={"group_id": [dataset.id]},
                     document_ids_filter=document_ids_filter,
                 )
@@ -527,6 +551,7 @@
             child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

             for i in child_index_nodes:
+                assert i.index_node_id
                 segment_ids.append(i.segment_id)
                 if i.segment_id in child_chunk_map:
                     child_chunk_map[i.segment_id].append(i)
@@ -844,6 +869,10 @@
                     top_n=top_k,
                    query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
                 )
+                if not data_post_processor.rerank_runner and score_threshold:
+                    all_documents_item = self._filter_documents_by_vector_score_threshold(
+                        all_documents_item, score_threshold
+                    )

                 all_documents.extend(all_documents_item)
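The threshold-selection rule above, isolated as a self-contained sketch (the enum here is a stand-in for `core.rag.retrieval.retrieval_methods.RetrievalMethod`):

    from enum import StrEnum

    class RetrievalMethod(StrEnum):
        SEMANTIC_SEARCH = "semantic_search"
        HYBRID_SEARCH = "hybrid_search"

    def embedding_score_threshold(method: RetrievalMethod, score_threshold: float | None) -> float | None:
        # Hybrid search defers the user threshold to the rerank/fusion stage,
        # so the embedding-similarity cutoff is disabled at vector-search time.
        return 0.0 if method == RetrievalMethod.HYBRID_SEARCH else score_threshold

    assert embedding_score_threshold(RetrievalMethod.HYBRID_SEARCH, 0.7) == 0.0
    assert embedding_score_threshold(RetrievalMethod.SEMANTIC_SEARCH, 0.7) == 0.7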
+ """ + + def __init__(self, dataset: Dataset): + self._dataset = dataset + self._real: Embeddings | None = None + + def _ensure(self) -> Embeddings: + if self._real is None: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) + embedding_model = model_manager.get_model_instance( + tenant_id=self._dataset.tenant_id, + provider=self._dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=self._dataset.embedding_model, + ) + self._real = CacheEmbedding(embedding_model) + return self._real + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self._ensure().embed_documents(texts) + + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: + return self._ensure().embed_multimodal_documents(multimodel_documents) + + def embed_query(self, text: str) -> list[float]: + return self._ensure().embed_query(text) + + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: + return self._ensure().embed_multimodal_query(multimodel_document) + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return await self._ensure().aembed_documents(texts) + + async def aembed_query(self, text: str) -> list[float]: + return await self._ensure().aembed_query(text) + + class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: @@ -60,7 +112,11 @@ class Vector: "original_chunk_id", ] self._dataset = dataset - self._embeddings = self._get_embeddings() + # Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists) + # never transitively trigger billing API calls during ``Vector(dataset)`` + # construction. The real embedding model is materialized only when an + # ``embed_*`` method is actually invoked (i.e. create / search paths). 
+ self._embeddings: Embeddings = _LazyEmbeddings(dataset) self._attributes = attributes self._vector_processor = self._init_vector() diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 8e9ebdd17a..78305a6ac0 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,14 +3,15 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.enums import SegmentType class DatasetDocumentStore: @@ -127,6 +128,7 @@ class DatasetDocumentStore: if save_child: if doc.children: for position, child in enumerate(doc.children, start=1): + assert self._document_id child_segment = ChildChunk( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, @@ -137,7 +139,7 @@ class DatasetDocumentStore: index_node_hash=child.metadata.get("doc_hash"), content=child.page_content, word_count=len(child.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=self._user_id, ) db.session.add(child_segment) @@ -163,6 +165,7 @@ class DatasetDocumentStore: ) # add new child chunks for position, child in enumerate(doc.children, start=1): + assert self._document_id child_segment = ChildChunk( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, @@ -173,7 +176,7 @@ class DatasetDocumentStore: index_node_hash=child.metadata.get("doc_hash"), content=child.page_content, word_count=len(child.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=self._user_id, ) db.session.add(child_segment) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index f5f5f541da..a9995778f7 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,8 +4,6 @@ import pickle from typing import Any, cast import numpy as np -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -15,6 +13,8 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding @@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" # use doc embedding cache or store if not exists multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] diff --git 
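The deferral pattern behind `_LazyEmbeddings`, reduced to a runnable sketch; the eager class below stands in for `ModelManager.get_model_instance` plus `CacheEmbedding`:

    from collections.abc import Callable

    class EagerEmbeddings:
        def __init__(self) -> None:
            print("materialized")  # stands in for the billing/model lookup

        def embed_query(self, text: str) -> list[float]:
            return [0.0]

    class LazyEmbeddingsSketch:
        def __init__(self, factory: Callable[[], EagerEmbeddings]) -> None:
            self._factory = factory
            self._real: EagerEmbeddings | None = None

        def _ensure(self) -> EagerEmbeddings:
            if self._real is None:
                self._real = self._factory()
            return self._real

        def embed_query(self, text: str) -> list[float]:
            return self._ensure().embed_query(text)

    proxy = LazyEmbeddingsSketch(EagerEmbeddings)  # construction is free
    proxy.embed_query("hello")  # first embed call prints "materialized"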
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 8e9ebdd17a..78305a6ac0 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -3,14 +3,15 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import Any

-from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import delete, func, select

 from core.model_manager import ModelManager
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
+from models.enums import SegmentType


 class DatasetDocumentStore:
@@ -127,6 +128,7 @@ class DatasetDocumentStore:
                 if save_child:
                     if doc.children:
                         for position, child in enumerate(doc.children, start=1):
+                            assert self._document_id
                             child_segment = ChildChunk(
                                 tenant_id=self._dataset.tenant_id,
                                 dataset_id=self._dataset.id,
@@ -137,7 +139,7 @@
                                 index_node_hash=child.metadata.get("doc_hash"),
                                 content=child.page_content,
                                 word_count=len(child.page_content),
-                                type="automatic",
+                                type=SegmentType.AUTOMATIC,
                                 created_by=self._user_id,
                             )
                             db.session.add(child_segment)
@@ -163,6 +165,7 @@
                         )
                         # add new child chunks
                         for position, child in enumerate(doc.children, start=1):
+                            assert self._document_id
                             child_segment = ChildChunk(
                                 tenant_id=self._dataset.tenant_id,
                                 dataset_id=self._dataset.id,
@@ -173,7 +176,7 @@
                                 index_node_hash=child.metadata.get("doc_hash"),
                                 content=child.page_content,
                                 word_count=len(child.page_content),
-                                type="automatic",
+                                type=SegmentType.AUTOMATIC,
                                 created_by=self._user_id,
                             )
                             db.session.add(child_segment)
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index f5f5f541da..a9995778f7 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -4,8 +4,6 @@ import pickle
 from typing import Any, cast

 import numpy as np
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError
@@ -15,6 +13,8 @@ from core.model_manager import ModelInstance
 from core.rag.embedding.embedding_base import Embeddings
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
 from libs import helper
 from models.dataset import Embedding
@@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings):

         return text_embeddings

-    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+    def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
         """Embed file documents."""
         # use doc embedding cache or store if not exists
         multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py
index ab190d2c42..7ae5c09ab7 100644
--- a/api/core/rag/embedding/embedding_base.py
+++ b/api/core/rag/embedding/embedding_base.py
@@ -11,7 +11,7 @@ class Embeddings(ABC):
         raise NotImplementedError

     @abstractmethod
-    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+    def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
         """Embed file documents."""
         raise NotImplementedError
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 449be6a448..b679edab36 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -94,16 +94,18 @@ class ExtractProcessor:
         cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
     ) -> list[Document]:
         if extract_setting.datasource_type == DatasourceType.FILE:
+            upload_file = extract_setting.upload_file
             with tempfile.TemporaryDirectory() as temp_dir:
+                upload_file = extract_setting.upload_file
                 if not file_path:
-                    assert extract_setting.upload_file is not None, "upload_file is required"
-                    upload_file: UploadFile = extract_setting.upload_file
+                    assert upload_file is not None, "upload_file is required"
                     suffix = Path(upload_file.key).suffix
                     # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
                     file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
+                assert upload_file is not None, "upload_file is required"
                 etl_type = dify_config.ETL_TYPE
                 extractor: BaseExtractor | None = None
                 if etl_type == "Unstructured":
@@ -113,6 +115,7 @@
                     if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
+                        assert upload_file is not None
                         extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension in {".md", ".markdown", ".mdx"}:
                         extractor = (
@@ -123,6 +126,7 @@
                     elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
+                        assert upload_file is not None
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension == ".doc":
                         extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -149,12 +153,14 @@
                     if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
+                        assert upload_file is not None
                         extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension in {".md", ".markdown", ".mdx"}:
                         extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                     elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
+                        assert upload_file is not None
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension == ".csv":
                         extractor = CSVExtractor(file_path, autodetect_encoding=True)
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py
index 89bdd56a6c..556158cf00 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_app.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py
@@ -174,21 +174,25 @@ class FirecrawlApp:
         return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

     def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
+        response: httpx.Response | None = None
         for attempt in range(retries):
             response = httpx.post(url, headers=headers, json=data)
             if response.status_code == 502:
                 time.sleep(backoff_factor * (2**attempt))
             else:
                 return response
+        assert response is not None, "retries must be at least 1"
         return response

     def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
+        response: httpx.Response | None = None
         for attempt in range(retries):
             response = httpx.get(url, headers=headers)
             if response.status_code == 502:
                 time.sleep(backoff_factor * (2**attempt))
             else:
                 return response
+        assert response is not None, "retries must be at least 1"
         return response

     def _handle_error(self, response, action):
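The retry shape shared by `_post_request` and `_get_request`, as a standalone sketch (httpx and the 502-only backoff mirror the diff; the helper name is illustrative):

    import time

    import httpx

    def get_with_backoff(url: str, retries: int = 3, backoff_factor: float = 0.5) -> httpx.Response:
        response: httpx.Response | None = None
        for attempt in range(retries):
            response = httpx.get(url)
            if response.status_code == 502:
                time.sleep(backoff_factor * (2**attempt))  # 0.5s, 1s, 2s, ...
            else:
                return response
        assert response is not None, "retries must be at least 1"
        return response  # the last 502 response once retries are exhausted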
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 052fca930d..0330a43b28 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -3,6 +3,7 @@
 Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
 """

+import inspect
 import logging
 import mimetypes
 import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
         file_path: Path to the file to load.
     """

+    _closed: bool
+
     def __init__(self, file_path: str, tenant_id: str, user_id: str):
         """Initialize with file path."""
+        self._closed = False
         self.file_path = file_path
         self.tenant_id = tenant_id
         self.user_id = user_id
@@ -65,9 +69,27 @@
         elif not os.path.isfile(self.file_path):
             raise ValueError(f"File path {self.file_path} is not a valid file or url")

+    def close(self) -> None:
+        """Best-effort cleanup for downloaded temporary files."""
+        if getattr(self, "_closed", False):
+            return
+
+        self._closed = True
+        temp_file = getattr(self, "temp_file", None)
+        if temp_file is None:
+            return
+
+        try:
+            close_result = temp_file.close()
+            if inspect.isawaitable(close_result):
+                close_awaitable = getattr(close_result, "close", None)
+                if callable(close_awaitable):
+                    close_awaitable()
+        except Exception:
+            logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
+
     def __del__(self):
-        if hasattr(self, "temp_file"):
-            self.temp_file.close()
+        self.close()

     def extract(self) -> list[Document]:
         """Load given path as single page."""
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index a487c49053..7ffa9afafd 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -7,16 +7,6 @@ from typing import Any, TypedDict, cast

 logger = logging.getLogger(__name__)

-from graphon.file import File, FileTransferMethod, FileType, file_manager
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContentUnionTypes,
-    TextPromptMessageContent,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
 from sqlalchemy import select

 from core.app.file_access import DatabaseFileAccessController
@@ -43,6 +33,16 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
 from core.workflow.file_reference import build_file_reference
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping
+from graphon.file import File, FileTransferMethod, FileType, file_manager
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentUnionTypes,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
 from libs import helper
 from models import UploadFile
 from models.account import Account
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             try:
                 # Create File object directly (similar to DatasetRetrieval)
                 file_obj = File(
-                    id=upload_file.id,
+                    file_id=upload_file.id,
                     filename=upload_file.name,
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
-                    type=FileType.IMAGE,
+                    file_type=FileType.IMAGE,
                     transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     reference=build_file_reference(
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 087736d0b0..4ebf095904 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -2,9 +2,10 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any

-from graphon.file import File
 from pydantic import BaseModel, Field

+from graphon.file import File
+

 class ChildDocument(BaseModel):
     """Class for storing a piece of text and associated metadata."""
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py
index a8d37845a5..bce08f998f 100644
--- a/api/core/rag/rerank/rerank_model.py
+++ b/api/core/rag/rerank/rerank_model.py
@@ -1,8 +1,5 @@
 import base64

-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
-
 from core.model_manager import ModelInstance, ModelManager
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.query_type import QueryType
@@ -10,6 +7,8 @@ from core.rag.models.document import Document
 from core.rag.rerank.rerank_base import BaseRerankRunner
 from extensions.ext_database import db
 from extensions.ext_storage import storage
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from models.model import UploadFile
diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py
index 49123e13d0..d0732b269a 100644
--- a/api/core/rag/rerank/weight_rerank.py
+++ b/api/core/rag/rerank/weight_rerank.py
@@ -2,7 +2,6 @@ import math
 from collections import Counter

 import numpy as np
-from graphon.model_runtime.entities.model_entities import ModelType

 from core.model_manager import ModelManager
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.rerank.entity.weight import VectorSetting, Weights
 from core.rag.rerank.rerank_base import BaseRerankRunner
+from graphon.model_runtime.entities.model_entities import ModelType


 class WeightRerankRunner(BaseRerankRunner):
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 8ebc840b99..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -9,11 +9,6 @@ from collections.abc import Generator, Mapping
 from typing import Any, Union, cast

 from flask import Flask, current_app
-from graphon.file import File, FileTransferMethod, FileType
-from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import and_, func, literal, or_, select, update
 from sqlalchemy.orm import sessionmaker
@@ -69,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import (
 )
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from libs.helper import parse_uuid_str_or_none
 from libs.json_in_md_parser import parse_and_check_json_markdown
 from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
         if attachments_with_bindings:
             for _, upload_file in attachments_with_bindings:
                 attachment_info = File(
-                    id=upload_file.id,
+                    file_id=upload_file.id,
                     filename=upload_file.name,
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
-                    type=FileType.IMAGE,
+                    file_type=FileType.IMAGE,
                     transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     reference=build_file_reference(
diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py
index 9a14d41716..29abae4280 100644
--- a/api/core/rag/retrieval/output_parser/react_output.py
+++ b/api/core/rag/retrieval/output_parser/react_output.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import NamedTuple, Union
+from typing import Any, NamedTuple, Union


 @dataclass
@@ -10,7 +10,7 @@ class ReactAction:

     tool: str
     """The name of the Tool to execute."""
-    tool_input: Union[str, dict]
+    tool_input: Union[str, dict[str, Any]]
     """The input to pass in to the Tool."""
     log: str
     """Additional information to log about the action."""
@@ -19,7 +19,7 @@ class ReactFinish(NamedTuple):
     """The final return value of an ReactFinish."""

-    return_values: dict
+    return_values: dict[str, Any]
     """Dictionary of return values."""
     log: str
     """Additional information to log about the return value"""
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index dce7b6226c..dd17545c86 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -1,10 +1,9 @@
 from typing import Union

-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage


 class FunctionCallMultiDatasetRouter:
@@ -29,10 +28,10 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result: LLMResult = model_instance.invoke_llm(
+        result: LLMResult = model_instance.invoke_llm(  # pyright: ignore[reportCallIssue, reportArgumentType]
             prompt_messages=prompt_messages,
             tools=dataset_tools,
-            stream=False,
+            stream=False,  # pyright: ignore[reportArgumentType]
             model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
         )
         usage = result.usage or LLMUsage.empty_usage()
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index dd280cdf6a..21a9d04f7f 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -1,9 +1,5 @@
 from collections.abc import Generator, Sequence
-from typing import Union
-
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import ModelType
+from typing import Any, Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.app.llm import deduct_llm_quota
@@ -12,6 +8,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import ModelType

 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -139,7 +138,7 @@ class ReactMultiDatasetRouter:

     def _invoke_llm(
         self,
-        completion_param: dict,
+        completion_param: dict[str, Any],
         model_instance: ModelInstance,
         prompt_messages: list[PromptMessage],
         stop: list[str],
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 3383c7f3bd..52c9a02f97 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -4,13 +4,12 @@ from __future__ import annotations

 import codecs
 import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
 from typing import Any, Literal

-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
-
 from core.model_manager import ModelInstance
 from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer


 class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -22,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
         cls: type[T],
         embedding_model_instance: ModelInstance | None,
-        allowed_special: Literal["all"] | set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ) -> T:
         def _token_encoder(texts: list[str]) -> list[int]:
@@ -41,6 +40,7 @@

             return [len(text) for text in texts]

+        _ = _token_encoder  # kept for future token-length wiring
         return cls(length_function=_character_encoder, **kwargs)
[{}] * len(texts) documents = [] @@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter): self, encoding_name: str = "gpt2", model_name: str | None = None, - allowed_special: Literal["all"] | Set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ): """Create a new TextSplitter.""" @@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter): else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc - self._allowed_special = allowed_special - self._disallowed_special = disallowed_special + self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special + self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index b07c63fdf0..e87d1cd6b2 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -7,11 +7,11 @@ providing improved performance by offloading database operations to background w import logging -from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index cdb3af01a8..2451563317 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -8,7 +8,6 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,6 +15,7 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from graphon.entities import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index ce3ad15759..4e83e70799 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 72d9394149..740d727e26 100644 --- 
a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,13 +4,11 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, @@ -19,6 +17,8 @@ from core.workflow.human_input_compat import ( InteractiveSurfaceDeliveryMethod, is_human_input_webapp_enabled, ) +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index d74cc8f231..6be3902317 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -5,13 +5,13 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 13e885672a..b036687bc9 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,10 +10,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any import psycopg2.errors -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -23,6 +19,10 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import 
WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 6e26664ac2..e267c1abd9 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -254,7 +254,7 @@ def resolve_dify_schema_refs( return resolver.resolve(schema) -def _remove_metadata_fields(schema: dict) -> dict: +def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]: """ Remove metadata fields from schema that shouldn't be included in resolved output diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index e539074303..95660ab93b 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,15 +2,14 @@ import io from collections.abc import Generator from typing import Any -from graphon.file import FileType -from graphon.file.file_manager import download -from graphon.model_runtime.entities.model_entities import ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f49c669fe0..ac3820f1ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,13 +2,12 @@ import io from collections.abc import Generator from typing import Any -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 14af63a962..d41503e1e6 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,12 +1,11 @@ from __future__ import annotations -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage - from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed 
at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 0a2c37c563..168e5f4493 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,7 +6,6 @@ from typing import Any, Union from urllib.parse import urlencode import httpx -from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -14,6 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError +from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 410ec72baf..42a88c0003 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,7 +2,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -10,6 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 10710c4376..4e07b7157a 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import Any from pydantic import BaseModel, Field @@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel): # icon icon: str | None = None # openapi operation - openapi: dict + openapi: dict[str, Any] # output schema output_schema: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 4c3efd6ff9..2b26832b44 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError): pass +class ApiToolProviderNotFoundError(ValueError): + error_code = "api_tool_provider_not_found" + provider_name: str + tenant_id: str + + def __init__(self, provider_name: str, tenant_id: str): + self.provider_name = provider_name + self.tenant_id = tenant_id + super().__init__(f"api provider {provider_name} does not exist") + + class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): error_code = "workflow_tool_human_input_not_supported" description = "Workflow with Human Input nodes cannot be published as a workflow tool." 
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index f6d09472b3..00fc8a8282 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,8 +6,6 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata - from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -23,6 +21,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 1afaa9cfaf..3caacb8706 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,7 +7,6 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast -from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -33,6 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile @@ -47,7 +47,7 @@ class ToolEngine: @staticmethod def agent_invoke( tool: Tool, - tool_parameters: Union[str, dict], + tool_parameters: Union[str, dict[str, Any]], user_id: str, tenant_id: str, message: Message, @@ -85,7 +85,7 @@ class ToolEngine: invocation_meta_dict: dict[str, ToolInvokeMeta] = {} def message_callback( - invocation_meta_dict: dict[str, Any], + invocation_meta_dict: dict[str, ToolInvokeMeta], messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None], ): for message in messages: diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index d8674b3af9..c87e8a3ae0 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,7 +9,6 @@ from mimetypes import guess_extension, guess_type from uuid import uuid4 import httpx -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from sqlalchemy import select from configs import dify_config @@ -17,6 +16,7 @@ from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile @@ -28,7 +28,7 @@ class ToolFileManager: def _build_graph_file_reference(tool_file: ToolFile) -> File: extension = guess_extension(tool_file.mimetype) or ".bin" return File( - type=get_file_type_by_mime_type(tool_file.mimetype), + file_type=get_file_type_by_mime_type(tool_file.mimetype), transfer_method=FileTransferMethod.TOOL_FILE, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), diff --git a/api/core/tools/tool_label_manager.py 
b/api/core/tools/tool_label_manager.py index 58190d1089..d8969a3391 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,5 @@ from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -19,10 +20,18 @@ class ToolLabelManager: return list(set(tool_labels)) @classmethod - def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): + def update_tool_labels( + cls, controller: ToolProviderController, labels: list[str], session: Session | None = None + ) -> None: """ Update tool labels + + :param controller: tool provider controller + :param labels: list of tool labels + :param session: database session; if None, a new session will be created + :return: None """ + labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): @@ -30,26 +39,46 @@ class ToolLabelManager: else: raise ValueError("Unsupported tool type") + if session is not None: + cls._update_tool_labels_logics(session, provider_id, controller, labels) + else: + with sessionmaker(db.engine).begin() as _session: + cls._update_tool_labels_logics(_session, provider_id, controller, labels) + + @classmethod + def _update_tool_labels_logics( + cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str] + ) -> None: + """ + Core logic for updating tool labels + + :param session: database session + :param provider_id: tool provider ID + :param controller: tool provider controller + :param labels: list of tool labels + :return: None + """ + # delete old labels - db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) + _ = session.execute( + delete(ToolLabelBinding).where( + ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type + ) + ) # insert new labels for label in labels: - db.session.add( - ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type, - label_name=label, - ) - ) - - db.session.commit() + session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label)) @classmethod def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: """ Get tool labels + + :param controller: tool provider controller + :return: list of tool labels (str) """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): @@ -60,9 +89,11 @@ class ToolLabelManager: ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type, ) - labels = db.session.scalars(stmt).all() - return list(labels) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + labels: list[str] = list(_session.scalars(stmt).all()) + + return labels @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: @@ -78,16 +109,22 @@ class ToolLabelManager: if not tool_providers: return {} + provider_ids: list[str] = [] + provider_types: set[str] = set() + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - - provider_ids = [] - for controller in tool_providers: -
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) + provider_types.add(controller.provider_type) - labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() + labels: list[ToolLabelBinding] = [] + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + stmt = select(ToolLabelBinding).where( + ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types)) + ) + labels = list(_session.scalars(stmt).all()) tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index be13d40f3e..87cf6d7085 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -8,7 +8,6 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa -from graphon.runtime import VariablePool from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session @@ -29,14 +28,13 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db +from graphon.runtime import VariablePool from models.provider_ids import ToolProviderID from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -62,6 +60,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -1083,7 +1082,12 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.value + if not isinstance(variable_selector, list) or not all( + isinstance(selector_part, str) for selector_part in variable_selector + ): + raise ToolParameterError("Variable tool input must be a variable selector") + variable = variable_pool.get(variable_selector) if variable is None: raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 03e3c5918d..b6890b2611 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,7 +1,6 @@ import threading from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import 
select @@ -15,6 +14,7 @@ from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 6a189fa6aa..0d1dc7273b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -39,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset_id: str user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity - inputs: dict + inputs: dict[str, Any] @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 81c85bc90d..5679466cbc 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -9,11 +9,11 @@ from uuid import UUID import numpy as np import pytz -from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account @@ -41,6 +41,10 @@ def safe_json_value(v): return v.hex() elif isinstance(v, memoryview): return v.tobytes().hex() + elif isinstance(v, np.integer): + return int(v) + elif isinstance(v, np.floating): + return float(v) elif isinstance(v, np.ndarray): return v.tolist() elif isinstance(v, dict): diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8d6f83dc07..a3623d4ecd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,6 +8,9 @@ import json from decimal import Decimal from typing import cast +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -18,12 +21,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder - -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/system_oauth_encryption.py 
b/api/core/tools/utils/system_encryption.py similarity index 57% rename from api/core/tools/utils/system_oauth_encryption.py rename to api/core/tools/utils/system_encryption.py index 6b7007842d..ca7e6a13fe 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_encryption.py @@ -14,23 +14,23 @@ from configs import dify_config logger = logging.getLogger(__name__) -class OAuthEncryptionError(Exception): - """OAuth encryption/decryption specific error""" +class EncryptionError(Exception): + """Encryption/decryption specific error""" pass -class SystemOAuthEncrypter: +class SystemEncrypter: """ - A simple OAuth parameters encrypter using AES-CBC encryption. + A simple parameters encrypter using AES-CBC encryption. - This class provides methods to encrypt and decrypt OAuth parameters + This class provides methods to encrypt and decrypt parameters using AES-CBC mode with a key derived from the application's SECRET_KEY. """ def __init__(self, secret_key: str | None = None): """ - Initialize the OAuth encrypter. + Initialize the encrypter. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY @@ -43,19 +43,19 @@ class SystemOAuthEncrypter: # Generate a fixed 256-bit key using SHA-256 self.key = hashlib.sha256(secret_key.encode()).digest() - def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + def encrypt_params(self, params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters. + Encrypt parameters. Args: - oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} Returns: Base64-encoded encrypted string Raises: - OAuthEncryptionError: If encryption fails - ValueError: If oauth_params is invalid + EncryptionError: If encryption fails + ValueError: If params is invalid """ try: @@ -66,7 +66,7 @@ class SystemOAuthEncrypter: cipher = AES.new(self.key, AES.MODE_CBC, iv) # Encrypt data - padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size) encrypted_data = cipher.encrypt(padded_data) # Combine IV and encrypted data @@ -76,20 +76,20 @@ class SystemOAuthEncrypter: return base64.b64encode(combined).decode() except Exception as e: - raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + raise EncryptionError(f"Encryption failed: {str(e)}") from e - def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters. + Decrypt parameters. 
Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary Raises: - OAuthEncryptionError: If decryption fails + EncryptionError: If decryption fails ValueError: If encrypted_data is invalid """ if not isinstance(encrypted_data, str): @@ -118,70 +118,70 @@ class SystemOAuthEncrypter: unpadded_data = unpad(decrypted_data, AES.block_size) # Parse JSON - oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) - if not isinstance(oauth_params, dict): + if not isinstance(params, dict): raise ValueError("Decrypted data is not a valid dictionary") - return oauth_params + return params except Exception as e: - raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + raise EncryptionError(f"Decryption failed: {str(e)}") from e # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: +def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter: """ - Create an OAuth encrypter instance. + Create an encrypter instance. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - return SystemOAuthEncrypter(secret_key=secret_key) + return SystemEncrypter(secret_key=secret_key) # Global encrypter instance (for backward compatibility) -_oauth_encrypter: SystemOAuthEncrypter | None = None +_encrypter: SystemEncrypter | None = None -def get_system_oauth_encrypter() -> SystemOAuthEncrypter: +def get_system_encrypter() -> SystemEncrypter: """ - Get the global OAuth encrypter instance. + Get the global encrypter instance. Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - global _oauth_encrypter - if _oauth_encrypter is None: - _oauth_encrypter = SystemOAuthEncrypter() - return _oauth_encrypter + global _encrypter + if _encrypter is None: + _encrypter = SystemEncrypter() + return _encrypter # Convenience functions for backward compatibility -def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: +def encrypt_system_params(params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters using the global encrypter. + Encrypt parameters using the global encrypter. Args: - oauth_params: OAuth parameters dictionary + params: Parameters dictionary Returns: Base64-encoded encrypted string """ - return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + return get_system_encrypter().encrypt_params(params) -def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: +def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters using the global encrypter. + Decrypt parameters using the global encrypter. 
Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary """ - return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) + return get_system_encrypter().decrypt_params(encrypted_data) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index ed3ed3e0de..94a2c0427b 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -105,7 +105,7 @@ class Article: def extract_using_readabilipy(html: str): - json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) + json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False) article = Article( title=json_article.get("title") or "", author=json_article.get("byline") or "", diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 2159eb8638..45718cadb6 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,13 +1,12 @@ from collections.abc import Mapping, Sequence from typing import Any +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError - class WorkflowToolConfigurationUtils: @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index a01004448a..5905fd919e 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Mapping -from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 552fbab1a4..cd8c6352b5 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,8 +5,6 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -22,6 +20,8 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping +from graphon.file import FILE_MODEL_IDENTITY, 
File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping @@ -277,7 +277,7 @@ class WorkflowTool(Tool): session.expunge(app) return app - def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]: + def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]: """ transform the tool parameters @@ -355,9 +355,12 @@ class WorkflowTool(Tool): return result, files - def _update_file_mapping(self, file_dict: dict[str, Any]): + def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) - transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) + transfer_method_value = file_dict.get("transfer_method") + if not isinstance(transfer_method_value, str): + raise ValueError("Workflow file mapping is missing a valid transfer_method") + transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.TOOL_FILE: file_dict["tool_file_id"] = file_id diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 61d1cd8540..24c1271488 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,7 +8,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -28,6 +27,7 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py similarity index 74% rename from api/core/workflow/human_input_compat.py rename to api/core/workflow/human_input_adapter.py index c95516a240..4b765e6aea 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_adapter.py @@ -1,8 +1,8 @@ -"""Workflow-layer adapters for legacy human-input payload keys. +"""Workflow-to-Graphon adapters for persisted node payloads. -Stored workflow graphs and editor payloads may still use Dify-specific human -input recipient keys. Normalize them here before handing configs to -`graphon` so graph-owned models only see graph-neutral field names. +Stored workflow graphs and editor payloads still contain a small set of +Dify-owned field spellings and value shapes. Adapt them here before handing the +payload to Graphon so Graphon-owned models only see current contracts. 
""" from __future__ import annotations @@ -14,12 +14,13 @@ from typing import Annotated, Any, ClassVar, Literal import bleach import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.variable_template_parser import VariableTemplateParser from graphon.runtime import VariablePool from graphon.variables.consts import SELECTORS_LENGTH -from markdown.extensions.tables import TableExtension -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter class DeliveryMethodType(enum.StrEnum): @@ -184,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None: return None -def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") @@ -214,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: - normalized = normalize_human_input_node_data_for_graph(node_data) + normalized = adapt_human_input_node_data_for_graph(node_data) raw_delivery_methods = normalized.get("delivery_methods") if not isinstance(raw_delivery_methods, list): return [] @@ -228,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b return False -def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") - if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: - return normalized - return normalize_human_input_node_data_for_graph(normalized) + node_type = normalized.get("type") + if node_type == BuiltinNodeTypes.HUMAN_INPUT: + return adapt_human_input_node_data_for_graph(normalized) + if node_type == BuiltinNodeTypes.TOOL: + return _adapt_tool_node_data_for_graph(normalized) + return normalized -def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_config) if normalized is None: raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") @@ -247,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) if data_mapping is None: return normalized - normalized["data"] = normalize_node_data_for_graph(data_mapping) + normalized["data"] = adapt_node_data_for_graph(data_mapping) return normalized +def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(node_data) + + raw_tool_configurations = normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, 
Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -290,9 +349,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py index f124b321d4..b02f69ec33 100644 --- a/api/core/workflow/human_input_forms.py +++ b/api/core/workflow/human_input_forms.py @@ -12,20 +12,16 @@ from collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session +from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token from extensions.ext_database import db from models.human_input import HumanInputFormRecipient, RecipientType -_FORM_TOKEN_PRIORITY = { - RecipientType.BACKSTAGE: 0, - RecipientType.CONSOLE: 1, - RecipientType.STANDALONE_WEB_APP: 2, -} - def load_form_tokens_by_form_id( form_ids: Sequence[str], *, session: Session | None = None, + surface: HumanInputSurface | None = None, ) -> dict[str, str]: """Load the preferred access token for each human input form.""" unique_form_ids = list(dict.fromkeys(form_ids)) @@ -33,23 +29,43 @@ def load_form_tokens_by_form_id( return {} if session is not None: - return _load_form_tokens_by_form_id(session, unique_form_ids) + return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface) with Session(bind=db.engine, expire_on_commit=False) as new_session: - return _load_form_tokens_by_form_id(new_session, unique_form_ids) + return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface) -def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: - tokens_by_form_id: dict[str, tuple[int, str]] = {} +def _load_form_tokens_by_form_id( + session: Session, + form_ids: Sequence[str], + *, + surface: HumanInputSurface | None = None, +) -> dict[str, str]: + 
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {} stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) for recipient in session.scalars(stmt): - priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) - if priority is None or not recipient.access_token: + if not recipient.access_token: continue + recipients_by_form_id.setdefault(recipient.form_id, []).append( + (recipient.recipient_type, recipient.access_token) + ) - candidate = (priority, recipient.access_token) - current = tokens_by_form_id.get(recipient.form_id) - if current is None or candidate[0] < current[0]: - tokens_by_form_id[recipient.form_id] = candidate + tokens_by_form_id: dict[str, str] = {} + for form_id, recipients in recipients_by_form_id.items(): + token = _get_surface_form_token(recipients, surface=surface) + if token is not None: + tokens_by_form_id[form_id] = token + return tokens_by_form_id - return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} + +def _get_surface_form_token( + recipients: Sequence[tuple[RecipientType, str]], + *, + surface: HumanInputSurface | None, +) -> str | None: + if surface == HumanInputSurface.SERVICE_API: + for recipient_type, token in recipients: + if recipient_type == RecipientType.STANDALONE_WEB_APP and token: + return token + + return get_preferred_form_token(recipients) diff --git a/api/core/workflow/human_input_policy.py b/api/core/workflow/human_input_policy.py new file mode 100644 index 0000000000..798eb8723f --- /dev/null +++ b/api/core/workflow/human_input_policy.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any + +from graphon.entities.pause_reason import PauseReasonType +from models.human_input import RecipientType + + +class HumanInputSurface(StrEnum): + SERVICE_API = "service_api" + CONSOLE = "console" + + +# Service API is intentionally narrower than other surfaces: app-token callers +# should only be able to act on end-user web forms, not internal console flows. +_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = { + HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}), + HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}), +} + +# A single HITL form can have multiple recipient records; this shared priority +# keeps every API surface consistent about which resume token to expose. 
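+# For example, given recipients [(RecipientType.STANDALONE_WEB_APP, "t-web"),
+# (RecipientType.CONSOLE, "t-console")], get_preferred_form_token() returns
+# "t-console": CONSOLE (priority 1) outranks STANDALONE_WEB_APP (priority 2).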
+_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def is_recipient_type_allowed_for_surface( + recipient_type: RecipientType | None, + surface: HumanInputSurface, +) -> bool: + if recipient_type is None: + return False + return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface] + + +def get_preferred_form_token( + recipients: Sequence[tuple[RecipientType, str]], +) -> str | None: + chosen_token: str | None = None + chosen_priority: int | None = None + for recipient_type, token in recipients: + priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type) + if priority is None or not token: + continue + if chosen_priority is None or priority < chosen_priority: + chosen_priority = priority + chosen_token = token + return chosen_token + + +def enrich_human_input_pause_reasons( + reasons: Sequence[Mapping[str, Any]], + *, + form_tokens_by_form_id: Mapping[str, str], + expiration_times_by_form_id: Mapping[str, int], +) -> list[dict[str, Any]]: + enriched: list[dict[str, Any]] = [] + for reason in reasons: + updated = dict(reason) + if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED: + form_id = updated.get("form_id") + if isinstance(form_id, str): + updated["form_token"] = form_tokens_by_form_id.get(form_id) + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is not None: + updated["expiration_time"] = expiration_time + enriched.append(updated) + return enriched diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index b04ac7da3d..de4eae1b22 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -5,22 +5,6 @@ from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.file.file_manager import file_manager -from graphon.graph.graph import NodeFactory -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.base.node import Node -from graphon.nodes.code.code_node import WorkflowCodeExecutor -from graphon.nodes.code.entities import CodeLanguage -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.nodes.document_extractor import UnstructuredApiConfig -from graphon.nodes.http_request import build_http_request_config -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session @@ -31,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import 
adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -56,6 +40,22 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: @@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. + # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. 
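+        # Illustrative effect (scenario assumed, names from this module): a config
+        # whose data.type is "llm" parses above as the permissive shared model and
+        # re-validates on the next line as the concrete LLMNodeData schema, so
+        # malformed node-specific fields fail at construction time, not mid-run.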
+ resolved_node_data = self._validate_resolved_node_data(node_class, node_data) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. """ - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 19cb3a7b0a..b8725853c4 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,38 +2,8 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities import LLMMode -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.nodes.llm.runtime_protocols import ( - PreparedLLMProtocol, - PromptMessageSerializerProtocol, - RetrieverAttachmentLoaderProtocol, -) -from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol -from graphon.nodes.runtime import ( - HumanInputFormStateProtocol, - HumanInputNodeRuntimeProtocol, - ToolNodeRuntimeProtocol, -) -from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, 
ToolRuntimeResolutionError -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) from sqlalchemy import select from sqlalchemy.orm import Session @@ -60,11 +30,41 @@ from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories import file_factory +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -76,13 +76,12 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage - _file_access_controller = DatabaseFileAccessController() @@ -174,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... 
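+    # Illustrative call sites (assumed, not part of this diff) showing how the
+    # overloads above narrow the return type from the literal `stream` argument:
+    #   prepared.invoke_llm(..., stream=False) -> LLMResult
+    #   prepared.invoke_llm(..., stream=True) -> Generator[LLMResultChunk, None, None]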
+ def invoke_llm( self, *, @@ -191,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + def invoke_llm_with_structured_output( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index bfd5536e4a..68a24e86b1 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,15 +3,13 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, get_system_text - from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, @@ -36,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index c52aad150b..51452c29a3 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index db74590ed7..f44681377d 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,6 +3,14 @@ from __future__ import annotations from 
collections.abc import Generator, Mapping from typing import Any, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -15,14 +23,6 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index be50edbc4d..a872774c98 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,8 +4,6 @@ import json from collections.abc import Sequence from typing import Any, cast -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -21,6 +19,8 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d9247b2593..f3006c4242 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,7 +1,12 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -12,13 +17,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser 
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment - from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError @@ -37,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index cad32f8d5b..28966f2392 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,9 +1,10 @@ from typing import Any, Literal, Union +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 04a10f9257..260881e49c 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,13 +1,13 @@ from typing import Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RetrievalSetting(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index bb72fe3881..9c1b7ab2c4 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,15 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template - from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import 
SystemVariableKey, get_system_segment, get_system_text +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -33,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 460ec693ce..3825f526a2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,11 +1,11 @@ from typing import Literal -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.nodes.llm.entities import ModelConfig, VisionConfig from pydantic import BaseModel, Field from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig __all__ = ["Condition"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 13624b27b3..25f73e446d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,8 +8,12 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -27,12 +31,6 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference - from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -51,6 +49,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or 
isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -60,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -283,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index 39e2008a2c..ea45dcf5c2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index bf5be2379a..23ed2cd408 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py 
b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index e50de11bb9..c848a86255 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID - from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 04f1f7e6bb..683c8d420f 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index a9753ab387..b46cc76a6e 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,10 @@ from collections.abc import Mapping -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index a30f877e4b..b261039448 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,)) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index d942a718cc..13c4f05bfd 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,6 +2,10 @@ import logging from collections.abc import Mapping from typing import Any +from 
core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult @@ -10,11 +14,6 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type - from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index d51cfadd09..b4ffb37549 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,11 +3,10 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index f0a5fbb400..4e2f603e5b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -3,20 +3,6 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import Any, TypedDict -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel -from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool - from configs import dify_config from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError @@ -40,6 +26,19 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import 
GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) diff --git a/api/dev/generate_swagger_specs.py b/api/dev/generate_swagger_specs.py new file mode 100644 index 0000000000..7e9688bfb4 --- /dev/null +++ b/api/dev/generate_swagger_specs.py @@ -0,0 +1,172 @@ +"""Generate Flask-RESTX Swagger 2.0 specs without booting the full backend. + +This helper intentionally avoids `app_factory.create_app()`. The normal backend +startup eagerly initializes database, Redis, Celery, and storage extensions, +which is unnecessary when the goal is only to serialize the Flask-RESTX +`/swagger.json` documents. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from dataclasses import dataclass +from pathlib import Path + +from flask import Flask +from flask_restx.swagger import Swagger + +logger = logging.getLogger(__name__) + +API_ROOT = Path(__file__).resolve().parents[1] +if str(API_ROOT) not in sys.path: + sys.path.insert(0, str(API_ROOT)) + + +@dataclass(frozen=True) +class SpecTarget: + route: str + filename: str + + +SPEC_TARGETS: tuple[SpecTarget, ...] = ( + SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"), + SpecTarget(route="/api/swagger.json", filename="web-swagger.json"), + SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"), +) + +_ORIGINAL_REGISTER_MODEL = Swagger.register_model +_ORIGINAL_REGISTER_FIELD = Swagger.register_field + + +def _apply_runtime_defaults() -> None: + """Force the small config surface required for Swagger generation.""" + + os.environ.setdefault("SECRET_KEY", "spec-export") + os.environ.setdefault("STORAGE_TYPE", "local") + os.environ.setdefault("STORAGE_LOCAL_PATH", "/tmp/dify-storage") + os.environ.setdefault("SWAGGER_UI_ENABLED", "true") + + from configs import dify_config + + dify_config.SECRET_KEY = os.environ["SECRET_KEY"] + dify_config.STORAGE_TYPE = "local" + dify_config.STORAGE_LOCAL_PATH = os.environ["STORAGE_LOCAL_PATH"] + dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true" + + +def _patch_swagger_for_inline_nested_dicts() -> None: + """Teach Flask-RESTX Swagger generation to tolerate inline nested field maps. + + Some existing controllers use `fields.Nested({...})` with a raw field mapping + instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous + dicts during schema registration, so this helper upgrades them into temporary + named models at export time. 
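+
+    Illustrative shape of the pattern being tolerated (model and field names are
+    assumed, not taken from the real controllers)::
+
+        api.model("Item", {"meta": fields.Nested({"id": fields.String()})})
+
+    The anonymous ``{"id": fields.String()}`` mapping is what gets promoted to a
+    temporary named model during export.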
+ """ + + if getattr(Swagger, "_dify_inline_nested_dict_patch", False): + return + + def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: + anonymous_models = getattr(self, "_anonymous_inline_models", None) + if anonymous_models is None: + anonymous_models = {} + self._anonymous_inline_models = anonymous_models + + anonymous_name = anonymous_models.get(id(nested_fields)) + if anonymous_name is None: + anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}" + anonymous_models[id(nested_fields)] = anonymous_name + self.api.model(anonymous_name, nested_fields) + + return self.api.models[anonymous_name] + + def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: + if isinstance(model, dict): + model = get_or_create_inline_model(self, model) + + return _ORIGINAL_REGISTER_MODEL(self, model) + + def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: + nested = getattr(field, "nested", None) + if isinstance(nested, dict): + field.model = get_or_create_inline_model(self, nested) # type: ignore + + _ORIGINAL_REGISTER_FIELD(self, field) + + Swagger.register_model = register_model_with_inline_dict_support + Swagger.register_field = register_field_with_inline_dict_support + Swagger._dify_inline_nested_dict_patch = True + + +def create_spec_app() -> Flask: + """Build a minimal Flask app that only mounts the Swagger-producing blueprints.""" + + _apply_runtime_defaults() + _patch_swagger_for_inline_nested_dicts() + + app = Flask(__name__) + + from controllers.console import bp as console_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + app.register_blueprint(console_bp) + app.register_blueprint(web_bp) + app.register_blueprint(service_api_bp) + + return app + + +def generate_specs(output_dir: Path) -> list[Path]: + """Write all Swagger specs to `output_dir` and return the written paths.""" + + output_dir.mkdir(parents=True, exist_ok=True) + + app = create_spec_app() + client = app.test_client() + + written_paths: list[Path] = [] + for target in SPEC_TARGETS: + response = client.get(target.route) + if response.status_code != 200: + raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}") + + payload = response.get_json() + if not isinstance(payload, dict): + raise RuntimeError(f"unexpected response payload for {target.route}") + + output_path = output_dir / target.filename + output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + written_paths.append(output_path) + + return written_paths + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-o", + "--output-dir", + type=Path, + default=Path("openapi"), + help="Directory where the Swagger JSON files will be written.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + written_paths = generate_specs(args.output_dir) + + for path in written_paths: + logger.debug(path) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6b904b7d0d..fc118df5bc 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - 
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -119,14 +119,16 @@ elif [[ "${MODE}" == "job" ]]; then else if [[ "${DEBUG}" == "true" ]]; then - exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0} + export PORT=${DIFY_PORT:-5001} + exec python -m app else exec gunicorn \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ - --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \ --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ - app:app + app:socketio_app fi fi diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py index 5a8d0ee6f4..dff558988c 100644 --- a/api/enterprise/telemetry/draft_trace.py +++ b/api/enterprise/telemetry/draft_trace.py @@ -3,10 +3,9 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from graphon.enums import WorkflowNodeExecutionMetadataKey - from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit +from graphon.enums import WorkflowNodeExecutionMetadataKey from models.workflow import WorkflowNodeExecutionModel diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py index 9cda0bf90a..c564ace584 100644 --- a/api/enterprise/telemetry/metric_handler.py +++ b/api/enterprise/telemetry/metric_handler.py @@ -329,7 +329,7 @@ class EnterpriseMetricHandler: return include_content = exporter.include_content - attrs: dict = { + attrs: dict[str, Any] = { "dify.message.id": payload.get("message_id"), "dify.tenant_id": envelope.tenant_id, "dify.event.id": envelope.event_id, diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger 
= logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. 
- - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. - - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. 
- """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index b7e7a6e60f..0c535a1c5b 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,24 +22,25 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + with session_factory.create_session() as session: + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = db.session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, + document = session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) ) - ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 57412cc4ad..38e102d5fd 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.model import InstalledApp @@ -12,5 +12,6 @@ def handle(sender, **kwargs): app_id=app.id, app_owner_tenant_id=app.tenant_id, ) - db.session.add(installed_app) - db.session.commit() + with session_factory.create_session() as session: + session.add(installed_app) + session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 84be592b1a..5e2a456dce 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - - db.session.add(site) - db.session.commit() + with session_factory.create_session() as session: + 
session.add(site) + session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 7bd8e88231..ba9758175f 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,12 +1,11 @@ import logging -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity - from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 86b5b2bbf0..6769b94cde 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 86b0550187..340f514fcc 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import normalize_redis_key_prefix class _CelerySentinelKwargsDict(TypedDict): @@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict): password: str | None -class CelerySentinelTransportDict(TypedDict): +class CelerySentinelTransportDict(TypedDict, total=False): master_name: str | None sentinel_kwargs: _CelerySentinelKwargsDict + global_keyprefix: str class CelerySSLOptionsDict(TypedDict): @@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None: def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. 
Redis Sentinel) for Celery connections.""" + transport_options: CelerySentinelTransportDict | dict[str, Any] if dify_config.CELERY_USE_SENTINEL: - return CelerySentinelTransportDict( + transport_options = CelerySentinelTransportDict( master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, sentinel_kwargs=_CelerySentinelKwargsDict( socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, password=dify_config.CELERY_SENTINEL_PASSWORD, ), ) - return {} + else: + transport_options = {} + + global_keyprefix = get_celery_redis_global_keyprefix() + if global_keyprefix: + transport_options["global_keyprefix"] = global_keyprefix + + return transport_options + + +def get_celery_redis_global_keyprefix() -> str | None: + """Return the Redis transport prefix for Celery when namespace isolation is enabled.""" + normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + if not normalized_prefix: + return None + return f"{normalized_prefix}:" def init_app(app: DifyApp) -> Celery: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 20f05b8b9e..9f7f73765e 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import Any, Union, cast import redis from redis import RedisError @@ -18,17 +18,26 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import ( + normalize_redis_key_prefix, + serialize_redis_name, + serialize_redis_name_arg, + serialize_redis_name_args, +) from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel -if TYPE_CHECKING: - from redis.lock import Lock - logger = logging.getLogger(__name__) +_normalize_redis_key_prefix = normalize_redis_key_prefix +_serialize_redis_name = serialize_redis_name +_serialize_redis_name_arg = serialize_redis_name_arg +_serialize_redis_name_args = serialize_redis_name_args + + class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global @@ -59,68 +68,148 @@ class RedisClientWrapper: if self._client is None: self._client = client - if TYPE_CHECKING: - # Type hints for IDE support and static analysis - # These are not executed at runtime but provide type information - def get(self, name: str | bytes) -> Any: ... - - def set( - self, - name: str | bytes, - value: Any, - ex: int | None = None, - px: int | None = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: int | None = None, - pxat: int | None = None, - ) -> Any: ... - - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... - def setnx(self, name: str | bytes, value: Any) -> Any: ... - def delete(self, *names: str | bytes) -> Any: ... - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... - def expire( - self, - name: str | bytes, - time: int | timedelta, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... 
- def lock( - self, - name: str, - timeout: float | None = None, - sleep: float = 0.1, - blocking: bool = True, - blocking_timeout: float | None = None, - thread_local: bool = True, - ) -> Lock: ... - def zadd( - self, - name: str | bytes, - mapping: dict[str | bytes | int | float, float | int | str | bytes], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... - def zcard(self, name: str | bytes) -> Any: ... - def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... - - def __getattr__(self, item: str) -> Any: + def _require_client(self) -> redis.Redis | RedisCluster: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") - return getattr(self._client, item) + return self._client + + def _get_prefix(self) -> str: + return dify_config.REDIS_KEY_PREFIX + + def get(self, name: str | bytes) -> Any: + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: + return self._require_client().set( + _serialize_redis_name_arg(name, self._get_prefix()), + value, + ex=ex, + px=px, + nx=nx, + xx=xx, + keepttl=keepttl, + get=get, + exat=exat, + pxat=pxat, + ) + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) + + def setnx(self, name: str | bytes, value: Any) -> Any: + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) + + def delete(self, *names: str | bytes) -> Any: + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) + + def incr(self, name: str | bytes, amount: int = 1) -> Any: + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) + + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().expire( + _serialize_redis_name_arg(name, self._get_prefix()), + time, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) + + def exists(self, *names: str | bytes) -> Any: + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) + + def ttl(self, name: str | bytes) -> Any: + return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) + + def getdel(self, name: str | bytes) -> Any: + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) + + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Any: + return self._require_client().lock( + _serialize_redis_name(name, self._get_prefix()), + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: + 
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) + + def hgetall(self, name: str | bytes) -> Any: + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) + + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) + + def hlen(self, name: str | bytes) -> Any: + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) + + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().zadd( + _serialize_redis_name_arg(name, self._get_prefix()), + cast(Any, mapping), + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ) + + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) + + def zcard(self, name: str | bytes) -> Any: + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) + + def pubsub(self) -> PubSub: + return self._require_client().pubsub() + + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) + + def __getattr__(self, item: str) -> Any: + return getattr(self._require_client(), item) redis_client: RedisClientWrapper = RedisClientWrapper() diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5cc58f27c4..69d1f1ab07 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,11 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException + from graphon.model_runtime.errors.invoke import InvokeRateLimitError + try: from langfuse._utils import parse_error diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py new file mode 100644 index 0000000000..5ed82bac8d --- /dev/null +++ b/api/extensions/ext_socketio.py @@ -0,0 +1,5 @@ +import socketio # type: ignore[reportMissingTypeStubs] + +from configs import dify_config + +sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index db599c5d49..64ff0f0674 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import 
escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 2745141431..7f77a0437a 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index d0f3e2e244..544109276d 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -3,14 +3,14 @@ import logging import os import time -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 37952d6464..dc7654a25c 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,10 +13,6 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -26,6 +22,10 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from 
extensions.logstore.sql_escape import escape_identifier +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 23d324f9ea..fbf379b3e5 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol -from graphon.enums import BuiltinNodeTypes -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 335c5cc29e..ec3c78a12d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 6df5f62c15..56672d1fd4 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b9fdd9e1ca..75ddbba448 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. 
""" -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/extensions/redis_names.py b/api/extensions/redis_names.py new file mode 100644 index 0000000000..9e63416daf --- /dev/null +++ b/api/extensions/redis_names.py @@ -0,0 +1,32 @@ +from configs import dify_config + + +def normalize_redis_key_prefix(prefix: str | None) -> str: + """Normalize the configured Redis key prefix for consistent runtime use.""" + if prefix is None: + return "" + return prefix.strip() + + +def get_redis_key_prefix() -> str: + """Read and normalize the current Redis key prefix from config.""" + return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + + +def serialize_redis_name(name: str, prefix: str | None = None) -> str: + """Convert a logical Redis name into the physical name used in Redis.""" + normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix) + if not normalized_prefix: + return name + return f"{normalized_prefix}:{name}" + + +def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes: + """Prefix string Redis names while preserving bytes inputs unchanged.""" + if isinstance(name, bytes): + return name + return serialize_redis_name(name, prefix) + + +def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]: + return tuple(serialize_redis_name_arg(name, prefix) for name in names) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 96f5915ff0..cd7f7db295 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -2,6 +2,7 @@ import logging import os from collections.abc import Generator from pathlib import Path +from typing import Any import opendal from dotenv import dotenv_values @@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars: dict = dotenv_values(env_file_path) or {} + file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 7516d18c8e..1d2ad4d445 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,12 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference 
-from extensions.ext_database import db +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + file_id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + file_id=mapping.get("id"), + filename=upload_file.name, + extension="." 
+ upload_file.extension, + mime_type=upload_file.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -220,9 +222,9 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=filename, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + file_id=mapping.get("id"), + filename=tool_file.name, + file_type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." 
+ datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + file_id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + file_type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 5582b85c95..4b3d514238 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -4,9 +4,8 @@ from __future__ import annotations from collections.abc import Sequence -from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig - from core.app.file_access import FileAccessControllerProtocol +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig from models import MessageFile from .builders import build_from_mapping diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index db3a7f3015..dba4c84407 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence -from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference +from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 57205b5739..fd7acb14d3 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,6 +8,11 @@ shared conversion functions for legacy callers and tests. 
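
The `builders.py` changes above consistently swap the global `db.session` for a context-managed session and keep every use of the loaded row inside the `with` block, so the ORM object is never touched after its session closes. A generic SQLAlchemy illustration of that pattern, using a toy model and an in-memory engine rather than Dify's `session_factory`:

```python
# Generic illustration of the scoped-session pattern adopted above:
# open a short-lived session, run the lookup, and finish all work with
# the loaded row inside the `with` block. Toy model, not Dify code.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class UploadFileRow(Base):  # hypothetical stand-in for models.UploadFile
    __tablename__ = "upload_files"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine)

with session_factory() as session:
    session.add(UploadFileRow(name="a.txt"))
    session.commit()

with session_factory() as session:
    row = session.scalar(select(UploadFileRow).where(UploadFileRow.name == "a.txt"))
    if row is None:
        raise ValueError("Invalid upload file")
    # Build the result while the session (and the row) are still live.
    result = {"id": row.id, "filename": row.name}

print(result)
```
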
from collections.abc import Mapping, Sequence from typing import Any, cast +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -31,12 +36,6 @@ from graphon.variables.variables import ( VariableBase, ) -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) - __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4..d518114777 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1afcbdb5b9..bf5c9ffcb1 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,10 +3,10 @@ from __future__ import annotations from datetime import datetime from typing import Any -from graphon.file import File from pydantic import Field, field_validator, model_validator from fields.base import ResponseModel +from graphon.file import File type JSONValue = Any @@ -96,7 +96,7 @@ class ConversationAnnotation(ResponseModel): class ConversationAnnotationHitHistory(ResponseModel): - annotation_id: str + annotation_id: str = Field(validation_alias="id") annotation_create_account: SimpleAccount | None = None created_at: int | None = None @@ -143,7 +143,7 @@ class MessageDetail(ResponseModel): query: str message: JSONValue message_tokens: int - answer: str + answer: str = Field(validation_alias="re_sign_file_url_answer") answer_tokens: int provider_response_latency: float from_source: str @@ -156,7 +156,7 @@ class MessageDetail(ResponseModel): created_at: int | None = None agent_thoughts: list[AgentThought] message_files: list[MessageFile] - metadata: JSONValue + metadata: JSONValue = Field(validation_alias="message_metadata_dict") status: str error: str | None = None parent_message_id: str | None = None @@ -196,7 +196,7 @@ class ModelConfig(ResponseModel): class SimpleModelConfig(ResponseModel): - model: JSONValue | None = None + model: JSONValue | None = Field(default=None, validation_alias="model_dict") pre_prompt: str | None = None @@ -211,6 +211,11 @@ class SimpleMessageDetail(ResponseModel): def _normalize_inputs(cls, value: JSONValue) -> JSONValue: return format_files_contained(value) + @field_validator("message", mode="before") + @classmethod + def _normalize_message(cls, value: JSONValue) -> str: + return message_text(value) + class Conversation(ResponseModel): id: str @@ -227,15 +232,22 @@ class Conversation(ResponseModel): model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None - message: SimpleMessageDetail | None = None + message: SimpleMessageDetail | None = Field(default=None, 
validation_alias="first_message") + + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[Conversation] + has_more: bool = Field(validation_alias="has_next") + data: list[Conversation] = Field(validation_alias="items") class ConversationMessageDetail(ResponseModel): @@ -246,7 +258,14 @@ class ConversationMessageDetail(ResponseModel): from_account_id: str | None = None created_at: int | None = None model_config_: ModelConfig | None = Field(default=None, alias="model_config") - message: MessageDetail | None = None + message: MessageDetail | None = Field(default=None, validation_alias="first_message") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationWithSummary(ResponseModel): @@ -258,7 +277,7 @@ class ConversationWithSummary(ResponseModel): from_account_id: str | None = None from_account_name: str | None = None name: str - summary: str + summary: str = Field(validation_alias="summary_or_query") read_at: int | None = None created_at: int | None = None updated_at: int | None = None @@ -269,13 +288,20 @@ class ConversationWithSummary(ResponseModel): admin_feedback_stats: FeedbackStat | None = None status_count: StatusCount | None = None + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + class ConversationWithSummaryPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[ConversationWithSummary] + has_more: bool = Field(validation_alias="has_next") + data: list[ConversationWithSummary] = Field(validation_alias="items") class ConversationDetail(ResponseModel): @@ -293,6 +319,13 @@ class ConversationDetail(ResponseModel): user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + def to_timestamp(value: datetime | None) -> int | None: if value is None: diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c55014a368..e4219ba1ee 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,5 +1,13 @@ -from flask_restx import Namespace, fields +from __future__ import annotations +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import field_validator + +from fields.base import ResponseModel +from graphon.variables.types import SegmentType from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -29,6 +37,74 @@ conversation_variable_infinite_scroll_pagination_fields = { } +def _to_timestamp(value: datetime | int | None) -> int 
| None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type()) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields) diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index cfe0015918..67b320beaa 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,10 +3,10 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from graphon.file import helpers as file_helpers from pydantic import computed_field, field_validator from fields.base import ResponseModel +from graphon.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 1a871204a0..ca18f1c203 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,12 +3,12 @@ from __future__ import annotations from datetime import datetime from uuid import uuid4 -from graphon.file import File from pydantic import Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from fields.base import ResponseModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile +from graphon.file import File type JSONValueType = JSONValue diff --git a/api/fields/online_user_fields.py b/api/fields/online_user_fields.py new file mode 100644 index 0000000000..bdbe19679c --- /dev/null +++ b/api/fields/online_user_fields.py @@ -0,0 +1,16 @@ +from flask_restx import fields + +online_user_partial_fields = { + "user_id": fields.String, + "username": fields.String, + "avatar": fields.String, +} + +workflow_online_users_fields = { + "app_id": fields.String, + "users": fields.List(fields.Nested(online_user_partial_fields)), 
+} + +online_user_list_fields = { + "data": fields.List(fields.Nested(workflow_online_users_fields)), +} diff --git a/api/fields/raws.py b/api/fields/raws.py index 4c65cdab7a..ee6f53b360 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,5 @@ from flask_restx import fields + from graphon.file import File diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index d0e762f62b..1b2c71255d 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,17 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import field_validator + +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser, simple_end_user_fields +from fields.member_fields import SimpleAccount, simple_account_fields from fields.workflow_run_fields import ( + WorkflowRunForArchivedLogResponse, + WorkflowRunForLogResponse, build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, workflow_run_for_archived_log_fields, @@ -85,3 +94,55 @@ def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): copied_fields = workflow_archived_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) + + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: Any = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowArchivedLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForArchivedLogResponse | None = None + trigger_metadata: Any = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +class WorkflowArchivedLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowArchivedLogPartialResponse] diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py new file mode 100644 index 0000000000..c708dd3460 --- /dev/null +++ b/api/fields/workflow_comment_fields.py @@ -0,0 +1,96 @@ +from flask_restx import fields + +from libs.helper import AvatarUrlField, TimestampField + +# basic account fields for comments +account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, + "avatar_url": AvatarUrlField, +} + 
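
The pydantic response models introduced above (for example in `workflow_app_log_fields.py`) repeatedly pair `Field(validation_alias=...)` with a `mode="before"` validator that coerces `datetime` values to unix timestamps. A self-contained sketch of that pairing, with a plain `BaseModel` standing in for the project's `ResponseModel` and invented field names:

```python
# Self-contained sketch of the validation pattern used by the response
# models above. BaseModel stands in for the project's ResponseModel.
from datetime import datetime, timezone

from pydantic import BaseModel, Field, field_validator


class LogEntry(BaseModel):
    id: str
    # Populated from the source attribute `has_next` during validation.
    has_more: bool = Field(validation_alias="has_next")
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


entry = LogEntry.model_validate(
    {"id": "log-1", "has_next": True, "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc)}
)
assert entry.has_more is True
assert entry.created_at == 1704067200
```
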
+# Comment mention fields +workflow_comment_mention_fields = { + "mentioned_user_id": fields.String, + "mentioned_user_account": fields.Nested(account_fields, allow_null=True), + "reply_id": fields.String, +} + +# Comment reply fields +workflow_comment_reply_fields = { + "id": fields.String, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, +} + +# Basic comment fields (for list views) +workflow_comment_basic_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "reply_count": fields.Integer, + "mention_count": fields.Integer, + "participants": fields.List(fields.Nested(account_fields)), +} + +# Detailed comment fields (for single comment view) +workflow_comment_detail_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "replies": fields.List(fields.Nested(workflow_comment_reply_fields)), + "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)), +} + +# Comment creation response fields (simplified) +workflow_comment_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Comment update response fields (simplified) +workflow_comment_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} + +# Comment resolve response fields +workflow_comment_resolve_fields = { + "id": fields.String, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, +} + +# Reply creation response fields (simplified) +workflow_comment_reply_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Reply update response fields +workflow_comment_reply_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index b0b6cc0b48..6e947858ba 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.exposed_type().value, + "value_type": str(value.value_type.exposed_type()), "description": value.description, } if isinstance(value, dict): diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 35bb442c59..8c659086ed 100644 --- 
a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,7 +1,14 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import Field, field_validator + +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser, simple_end_user_fields +from fields.member_fields import SimpleAccount, simple_account_fields from libs.helper import TimestampField workflow_run_for_log_fields = { @@ -147,3 +154,174 @@ workflow_run_node_execution_fields = { workflow_run_node_execution_list_fields = { "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } + + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunForArchivedLogResponse(ResponseModel): + id: str + status: str | None = None + triggered_from: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + +class WorkflowRunForListResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_by_account: SimpleAccount | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + retry_index: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AdvancedChatWorkflowRunForListResponse(WorkflowRunForListResponse): + conversation_id: str | None = None + message_id: str | None = None + + +class AdvancedChatWorkflowRunPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[AdvancedChatWorkflowRunForListResponse] + + +class WorkflowRunPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[WorkflowRunForListResponse] + + +class WorkflowRunCountResponse(ResponseModel): + total: 
int + running: int + succeeded: int + failed: int + stopped: int + partial_succeeded: int = Field(validation_alias="partial-succeeded") + + +class WorkflowRunDetailResponse(ResponseModel): + id: str + version: str | None = None + graph: Any = Field(validation_alias="graph_dict") + inputs: Any = Field(validation_alias="inputs_dict") + status: str | None = None + outputs: Any = Field(validation_alias="outputs_dict") + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunNodeExecutionResponse(ResponseModel): + id: str + index: int | None = None + predecessor_node_id: str | None = None + node_id: str | None = None + node_type: str | None = None + title: str | None = None + inputs: Any = Field(default=None, validation_alias="inputs_dict") + process_data: Any = Field(default=None, validation_alias="process_data_dict") + outputs: Any = Field(default=None, validation_alias="outputs_dict") + status: str | None = None + error: str | None = None + elapsed_time: float | None = None + execution_metadata: Any = Field(default=None, validation_alias="execution_metadata_dict") + extras: Any = None + created_at: int | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + finished_at: int | None = None + inputs_truncated: bool | None = None + outputs_truncated: bool | None = None + process_data_truncated: bool | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunNodeExecutionListResponse(ResponseModel): + data: list[WorkflowRunNodeExecutionResponse] diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 36aa1cd3e8..b76a23eb3c 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -32,12 +33,13 @@ class Topic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.publish(self._topic, payload) + 
self._client.publish(self._redis_topic, payload) def as_subscriber(self) -> Subscriber: return self @@ -46,7 +48,7 @@ class Topic: return _RedisSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index dddc92d099..919d8d622e 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -30,12 +31,13 @@ class ShardedTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr] + self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr] def as_subscriber(self) -> Subscriber: return self @@ -44,7 +46,7 @@ class ShardedTopic: return _RedisShardedSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 983f785027..55ff6cd4f9 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -6,6 +6,7 @@ import threading from collections.abc import Iterator from typing import Self +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from libs.broadcast_channel.exc import SubscriptionClosedError from redis import Redis, RedisCluster @@ -35,7 +36,7 @@ class StreamsTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): self._client = redis_client self._topic = topic - self._key = f"stream:{topic}" + self._key = serialize_redis_name(f"stream:{topic}") self._retention_seconds = retention_seconds self.max_length = 5000 diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index ca8956e397..b5fe38342a 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock: timeout=self._ttl_seconds, thread_local=False, ) - acquired = bool(self._lock.acquire(*args, **kwargs)) + lock = self._lock + if lock is None: + raise RuntimeError("Redis lock initialization failed.") + acquired = bool(lock.acquire(*args, **kwargs)) self._acquired = acquired if acquired: self._start_heartbeat() diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 0828cf80bf..1519f07bb1 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -37,6 +37,7 @@ class EmailType(StrEnum): ENTERPRISE_CUSTOM = auto() QUEUE_MONITOR_ALERT = auto() DOCUMENT_CLEAN_NOTIFY = auto() + WORKFLOW_COMMENT_MENTION = auto() EMAIL_REGISTER = auto() EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto() RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto() @@ -453,6 +454,18 @@ def create_default_email_config() -> EmailI18nConfig: 
branded_template_path="clean_document_job_mail_template_zh-CN.html", ), }, + EmailType.WORKFLOW_COMMENT_MENTION: { + EmailLanguage.EN_US: EmailTemplate( + subject="You were mentioned in a workflow comment", + template_path="workflow_comment_mention_template_en-US.html", + branded_template_path="without-brand/workflow_comment_mention_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="你在工作流评论中被提及", + template_path="workflow_comment_mention_template_zh-CN.html", + branded_template_path="without-brand/workflow_comment_mention_template_zh-CN.html", + ), + }, EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: { EmailLanguage.EN_US: EmailTemplate( subject="You’ve reached your Sandbox Trigger Events limit", diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index 52fc787c79..838af2bf32 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,5 +1,5 @@ import contextvars -from collections.abc import Iterator +from collections.abc import Generator # Changed from Iterator from contextlib import contextmanager from typing import TYPE_CHECKING @@ -13,7 +13,7 @@ if TYPE_CHECKING: def preserve_flask_contexts( flask_app: Flask, context_vars: contextvars.Context, -) -> Iterator[None]: +) -> Generator[None, None, None]: # Changed from Iterator[None] """ A context manager that handles: 1. flask-login's UserProxy copy diff --git a/api/libs/helper.py b/api/libs/helper.py index f28de92927..ac69a11084 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,8 +16,6 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, TypeAdapter from pydantic.functional_validators import AfterValidator from typing_extensions import TypedDict @@ -25,6 +23,8 @@ from typing_extensions import TypedDict from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account @@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw): obj = obj["app"] if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: - return file_helpers.get_signed_file_url(obj.icon) + return build_icon_url(obj.icon_type, obj.icon) return None +def build_icon_url(icon_type: Any, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + + from models.model import IconType + + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE: + return None + return file_helpers.get_signed_file_url(icon) + + class AvatarUrlField(fields.Raw): def output(self, key, obj, **kwargs): if obj is None: diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 9b53918f24..934aacb45b 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -6,8 +6,8 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.db.session_factory import session_factory from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -95,27 +95,28 
@@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) @@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def sync_data_source(self, 
binding_id: str) -> None: # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ) - if data_source_binding: - # get all authorized pages - pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) - new_source_info = self._build_source_info( - workspace_name=source_info["workspace_name"], - workspace_icon=source_info["workspace_icon"], - workspace_id=source_info["workspace_id"], - pages=pages, - ) - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - raise ValueError("Data source binding not found") + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: pages: list[NotionPageSummary] = [] diff --git a/api/libs/token.py b/api/libs/token.py index a34db70764..5b043465ac 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -47,23 +47,17 @@ def _cookie_domain() -> str | None: def _real_cookie_name(cookie_name: str) -> str: if is_secure() and _cookie_domain() is None: return "__Host-" + cookie_name - else: - return cookie_name + return cookie_name def _try_extract_from_header(request: Request) -> str | None: auth_header = request.headers.get("Authorization") - if auth_header: - if " " not in auth_header: - return None - else: - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - return None - else: - return auth_token - return None + if not auth_header or " " not in auth_header: + return None + auth_scheme, auth_token = auth_header.split(None, 1) + if auth_scheme.lower() != "bearer": + return None + return auth_token def extract_refresh_token(request: Request) -> str | None: @@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None: - def _try_extract_passport_token_from_cookie(request: Request) -> str | None: - return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) - - def _try_extract_passport_token_from_header(request: Request) -> str | 
None: - return request.headers.get(HEADER_NAME_PASSPORT) - - ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) - return ret + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get( + HEADER_NAME_PASSPORT + ) def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"): @@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str): if not csrf_token: _unauthorized() - verified = {} try: verified = PassportService().verify(csrf_token) - except: + except Exception: _unauthorized() + raise # unreachable, but helps the type checker see verified is always bound if verified.get("sub") != user_id: _unauthorized() exp: int | None = verified.get("exp") - if not exp: + if not exp or exp < int(datetime.now(UTC).timestamp()): _unauthorized() - else: - time_now = int(datetime.now().timestamp()) - if exp < time_now: - _unauthorized() def generate_csrf_token(user_id: str) -> str: diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py new file mode 100644 index 0000000000..adcac3add0 --- /dev/null +++ b/api/libs/url_utils.py @@ -0,0 +1,3 @@ +def normalize_api_base_url(base_url: str) -> str: + """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes.""" + return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1" diff --git a/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py new file mode 100644 index 0000000000..0e188ec080 --- /dev/null +++ b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py @@ -0,0 +1,26 @@ +"""add qdrant_endpoint to tidb_auth_bindings + +Revision ID: 8574b23a38fd +Revises: 6b5f9f8b1a2c +Create Date: 2026-04-14 15:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8574b23a38fd" +down_revision = "6b5f9f8b1a2c" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True)) + + +def downgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.drop_column("qdrant_endpoint") diff --git a/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py new file mode 100644 index 0000000000..0548c932b5 --- /dev/null +++ b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py @@ -0,0 +1,90 @@ +"""Add workflow comments table + +Revision ID: 227822d22895 +Revises: 8574b23a38fd +Create Date: 2025-08-22 17:26:15.255980 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '227822d22895' +down_revision = '8574b23a38fd' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow_comments', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('position_x', sa.Float(), nullable=False), + sa.Column('position_y', sa.Float(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('resolved_at', sa.DateTime(), nullable=True), + sa.Column('resolved_by', models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey') + ) + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_replies', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey') + ) + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_mentions', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('reply_id', models.types.StringUUID(), nullable=True), + sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'), + sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey') + ) + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False) + batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.drop_index('comment_mentions_user_idx') + batch_op.drop_index('comment_mentions_reply_idx') + batch_op.drop_index('comment_mentions_comment_idx') + + op.drop_table('workflow_comment_mentions') + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.drop_index('comment_replies_created_at_idx') + batch_op.drop_index('comment_replies_comment_idx') + + op.drop_table('workflow_comment_replies') + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.drop_index('workflow_comments_created_at_idx') + batch_op.drop_index('workflow_comments_app_idx') + + op.drop_table('workflow_comments') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index fcae07f948..85be9ca3bd 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -9,6 +9,11 @@ from .account import ( TenantStatus, ) from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .comment import ( + WorkflowComment, + WorkflowCommentMention, + WorkflowCommentReply, +) from .dataset import ( AppDatasetJoin, Dataset, @@ -208,6 +213,9 @@ __all__ = [ "WorkflowAppLog", "WorkflowAppLogCreatedFrom", "WorkflowArchiveLog", + "WorkflowComment", + "WorkflowCommentMention", + "WorkflowCommentReply", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/comment.py b/api/models/comment.py new file mode 100644 index 0000000000..5d4a08e783 --- /dev/null +++ b/api/models/comment.py @@ -0,0 +1,219 @@ +"""Workflow comment models.""" + +from datetime import datetime +from typing import Optional + +import sqlalchemy as sa +from sqlalchemy import Index, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .account import Account +from .base import Base, gen_uuidv7_string +from .engine import db +from .types import StringUUID + + +class WorkflowComment(Base): + """Workflow comment model for canvas commenting functionality. + + Comments are associated with apps rather than specific workflow versions, + since an app has only one draft workflow at a time and comments should persist + across workflow version changes. 
+ + Attributes: + id: Comment ID + tenant_id: Workspace ID + app_id: App ID (primary association, comments belong to apps) + position_x: X coordinate on canvas + position_y: Y coordinate on canvas + content: Comment content + created_by: Creator account ID + created_at: Creation time + updated_at: Last update time + resolved: Whether comment is resolved + resolved_at: Resolution time + resolved_by: Resolver account ID + """ + + __tablename__ = "workflow_comments" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + Index("workflow_comments_app_idx", "tenant_id", "app_id"), + Index("workflow_comments_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) + resolved_by: Mapped[str | None] = mapped_column(StringUUID) + + # Relationships + replies: Mapped[list["WorkflowCommentReply"]] = relationship( + "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + ) + mentions: Mapped[list["WorkflowCommentMention"]] = relationship( + "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + ) + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + @property + def resolved_by_account(self): + """Get resolver account.""" + if hasattr(self, "_resolved_by_account_cache"): + return self._resolved_by_account_cache + if self.resolved_by: + return db.session.get(Account, self.resolved_by) + return None + + def cache_resolved_by_account(self, account: Account | None) -> None: + """Cache resolver account to avoid extra queries.""" + self._resolved_by_account_cache = account + + @property + def reply_count(self): + """Get reply count.""" + return len(self.replies) + + @property + def mention_count(self): + """Get mention count.""" + return len(self.mentions) + + @property + def participants(self): + """Get all participants (creator + repliers + mentioned users).""" + participant_ids: set[str] = set() + participants: list[Account] = [] + + # Use account properties to reuse preloaded caches and avoid hidden N+1. 
+ if self.created_by not in participant_ids: + participant_ids.add(self.created_by) + created_by_account = self.created_by_account + if created_by_account: + participants.append(created_by_account) + + for reply in self.replies: + if reply.created_by in participant_ids: + continue + participant_ids.add(reply.created_by) + reply_account = reply.created_by_account + if reply_account: + participants.append(reply_account) + + for mention in self.mentions: + if mention.mentioned_user_id in participant_ids: + continue + participant_ids.add(mention.mentioned_user_id) + mentioned_account = mention.mentioned_user_account + if mentioned_account: + participants.append(mentioned_account) + + return participants + + +class WorkflowCommentReply(Base): + """Workflow comment reply model. + + Attributes: + id: Reply ID + comment_id: Parent comment ID + content: Reply content + created_by: Creator account ID + created_at: Creation time + """ + + __tablename__ = "workflow_comment_replies" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + Index("comment_replies_comment_idx", "comment_id"), + Index("comment_replies_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + +class WorkflowCommentMention(Base): + """Workflow comment mention model. + + Mentions are only for internal accounts since end users + cannot access workflow canvas and commenting features. 
+ + Attributes: + id: Mention ID + comment_id: Parent comment ID + reply_id: Parent reply ID when the mention occurs in a reply, otherwise None + mentioned_user_id: Mentioned account ID + """ + + __tablename__ = "workflow_comment_mentions" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + Index("comment_mentions_comment_idx", "comment_id"), + Index("comment_mentions_reply_idx", "reply_id"), + Index("comment_mentions_user_idx", "mentioned_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + ) + mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + + @property + def mentioned_user_account(self): + """Get mentioned account.""" + if hasattr(self, "_mentioned_user_account_cache"): + return self._mentioned_user_account_cache + return db.session.get(Account, self.mentioned_user_id) + + def cache_mentioned_user_account(self, account: Account | None) -> None: + """Cache mentioned account to avoid extra queries.""" + self._mentioned_user_account_cache = account diff --git a/api/models/dataset.py b/api/models/dataset.py index a48afa7ea7..a00e9f7640 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1036,7 +1036,7 @@ class DocumentSegment(Base): return attachment_list -class ChildChunk(Base): +class ChildChunk(TypeBase): __tablename__ = "child_chunks" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), @@ -1046,29 +1046,42 @@ class ChildChunk(Base): ) # initial fields - id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - document_id = mapped_column(StringUUID, nullable=False) - segment_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - content = mapped_column(LongText, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False) word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(String(255), nullable=True) - index_node_hash = mapped_column(String(255), nullable=True) - type: Mapped[SegmentType] = mapped_column( - EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) - updated_by =
mapped_column(StringUUID, nullable=True) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=sa.func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) - indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - error = mapped_column(LongText, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column( + DateTime, nullable=True, insert_default=None, server_default=None, init=False + ) + completed_at: Mapped[datetime | None] = mapped_column( + DateTime, nullable=True, insert_default=None, server_default=None, init=False + ) + index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), + nullable=False, + server_default=sa.text("'automatic'"), + default=SegmentType.AUTOMATIC, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False) @property def dataset(self): @@ -1305,6 +1318,7 @@ class TidbAuthBinding(TypeBase): ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) + qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -1551,7 +1565,7 @@ class PipelineBuiltInTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False) @@ -1584,7 +1598,7 @@ class PipelineCustomizedTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1657,7 +1671,7 @@ class DocumentPipelineExecutionLog(TypeBase): datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) datasource_info: Mapped[str] = mapped_column(LongText, nullable=False) datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) 
created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1714,7 +1728,7 @@ class SegmentAttachmentBinding(TypeBase): ) -class DocumentSegmentSummary(Base): +class DocumentSegmentSummary(TypeBase): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1724,25 +1738,40 @@ class DocumentSegmentSummary(Base): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str] = mapped_column(LongText, nullable=True) - summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) - summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + status: Mapped[SummaryStatus] = mapped_column( + EnumText(SummaryStatus, length=32), + nullable=False, + server_default=sa.text("'generating'"), + default=SummaryStatus.GENERATING, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - error: Mapped[str] = mapped_column(LongText, nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - disabled_by = mapped_column(StringUUID, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def __repr__(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index 79c5d62f6a..7447d3efcb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy 
as sa -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 47b096d0bf..25c330b062 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,9 +14,6 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker @@ -24,7 +21,11 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] +from libs.url_utils import normalize_api_base_url from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping @@ -90,6 +91,19 @@ class EnabledConfig(TypedDict): enabled: bool +class SuggestedQuestionsAfterAnswerModelConfig(TypedDict): + provider: str + name: str + mode: NotRequired[str] + completion_params: NotRequired[dict[str, Any]] + + +class SuggestedQuestionsAfterAnswerConfig(TypedDict): + enabled: bool + model: NotRequired[SuggestedQuestionsAfterAnswerModelConfig] + prompt: NotRequired[str] + + class EmbeddingModelInfo(TypedDict): embedding_provider_name: str embedding_model_name: str @@ -219,7 +233,7 @@ class ModelConfig(TypedDict): class AppModelConfigDict(TypedDict): opening_statement: str | None suggested_questions: list[str] - suggested_questions_after_answer: EnabledConfig + suggested_questions_after_answer: SuggestedQuestionsAfterAnswerConfig speech_to_text: EnabledConfig text_to_speech: EnabledConfig retriever_resource: EnabledConfig @@ -446,7 +460,8 @@ class App(Base): @property def api_base_url(self) -> str: - return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return normalize_api_base_url(base) @property def tenant(self) -> Tenant | None: @@ -678,8 +693,13 @@ class AppModelConfig(TypeBase): return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) @property - def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return self._get_enabled_config(self.suggested_questions_after_answer) + def suggested_questions_after_answer_dict(self) -> SuggestedQuestionsAfterAnswerConfig: + return cast( + SuggestedQuestionsAfterAnswerConfig, + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer + else {"enabled": False}, + ) @property def speech_to_text_dict(self) -> EnabledConfig: 
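A quick sketch, not part of the patch itself, of what the new `libs/url_utils.py` helper guarantees for the `App.api_base_url` property above. The URLs are made-up examples; the expected values follow directly from the `rstrip`/`removesuffix` chain in `normalize_api_base_url`.

```python
# Hypothetical inputs; behavior follows normalize_api_base_url's
# rstrip("/") -> removesuffix("/v1") -> rstrip("/") -> + "/v1" chain.
from libs.url_utils import normalize_api_base_url

assert normalize_api_base_url("https://api.example.com") == "https://api.example.com/v1"
assert normalize_api_base_url("https://api.example.com/") == "https://api.example.com/v1"
# An existing /v1 suffix, with or without a trailing slash, is not duplicated.
assert normalize_api_base_url("https://api.example.com/v1") == "https://api.example.com/v1"
assert normalize_api_base_url("https://api.example.com/v1/") == "https://api.example.com/v1"
```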
@@ -1007,7 +1027,7 @@ class OAuthProviderApp(TypeBase): app_icon: Mapped[str] = mapped_column(String(255), nullable=False) client_id: Mapped[str] = mapped_column(String(255), nullable=False) client_secret: Mapped[str] = mapped_column(String(255), nullable=False) - app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict) + app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict) redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list) scope: Mapped[str] = mapped_column( String(255), @@ -1847,15 +1867,18 @@ class MessageAnnotation(TypeBase): ) id: Mapped[str] = mapped_column( - StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + StringUUID, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, ) app_id: Mapped[str] = mapped_column(StringUUID) question: Mapped[str] = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None) message_id: Mapped[str | None] = mapped_column(StringUUID, default=None) - hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -2159,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. return result -class UploadFile(Base): +class UploadFile(TypeBase): __tablename__ = "upload_files" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -2167,9 +2190,12 @@ class UploadFile(Base): ) # NOTE: The `id` field is generated within the application to minimize extra roundtrips - # (especially when generating `source_url`). - # The `server_default` serves as a fallback mechanism. - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + # (especially when generating `source_url`) and keep model metadata portable across databases. + id: Mapped[str] = mapped_column( + StringUUID, + init=False, + default_factory=lambda: str(uuid4()), + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) @@ -2177,16 +2203,6 @@ class UploadFile(Base): size: Mapped[int] = mapped_column(sa.Integer, nullable=False) extension: Mapped[str] = mapped_column(String(255), nullable=False) mime_type: Mapped[str] = mapped_column(String(255), nullable=True) - - # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. - # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[CreatorUserRole] = mapped_column( - EnumText(CreatorUserRole, length=255), - nullable=False, - server_default=sa.text("'account'"), - default=CreatorUserRole.ACCOUNT, - ) - # The `created_by` field stores the ID of the entity that created this upload file. # # If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`. 
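The `UploadFile` conversion above follows the same SQLAlchemy 2.0 dataclass-mapping pattern as `ChildChunk` and `DocumentSegmentSummary`: `init=False` keeps generated columns out of `__init__`, `default_factory` supplies the Python-side value, and `server_default` remains the database fallback. A minimal standalone sketch of that pattern, with invented table and field names:

```python
import uuid

from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class Base(MappedAsDataclass, DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "examples"

    # Generated app-side and excluded from __init__; callers never pass it.
    id: Mapped[str] = mapped_column(
        primary_key=True, init=False, default_factory=lambda: str(uuid.uuid4())
    )
    name: Mapped[str] = mapped_column()


row = Example(name="demo")  # only `name` is a constructor argument
assert row.id  # populated by default_factory at construction time
```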
@@ -2205,10 +2221,18 @@ class UploadFile(Base): # `used` may indicate whether the file has been utilized by another service. used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. + # Its value is derived from the `CreatorUserRole` enumeration. + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # `used_by` may indicate the ID of the user who utilized this file. - used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(String(255), nullable=True) + used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) source_url: Mapped[str] = mapped_column(LongText, default="") def __init__( @@ -2495,7 +2519,7 @@ class TraceAppConfig(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) - tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True) + tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/oauth.py b/api/models/oauth.py index 1db2552469..bd04d890d3 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any import sqlalchemy as sa from sqlalchemy import func @@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase): ) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) - system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) class DatasourceProvider(TypeBase): @@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase): provider: Mapped[str] = mapped_column(sa.String(128), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default") is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1) @@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) + client_params: Mapped[dict[str, Any]] = 
mapped_column(AdjustedJSON, nullable=False, default_factory=dict) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/provider.py b/api/models/provider.py index 8270961b31..2bb67d605b 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,10 +6,10 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column +from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase diff --git a/api/models/source.py b/api/models/source.py index 8078b32f8c..8fce7df205 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -25,7 +25,7 @@ class DataSourceOauthBinding(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) access_token: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index 8b767779ce..77dcbd13d4 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -4,9 +4,9 @@ from collections.abc import Callable, Mapping from functools import lru_cache from typing import Any -from graphon.file import File, FileTransferMethod - from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -44,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not 
isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. 
+ """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -77,12 +195,7 @@ def build_file_from_stored_mapping( pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index 63abf8c3b6..cb1723440b 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,19 +8,6 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.file.constants import maybe_file_object -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -37,35 +24,54 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: - from .model import AppMode, UploadFile + from .model import AppMode -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase - from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import 
helper from .account import Account from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom + +# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships +# must target the class object directly instead of relying on string lookup across registries. +from .model import UploadFile from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -291,7 +297,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -490,7 +496,7 @@ class Workflow(Base): # bug :return: hash """ - entity = {"graph": self.graph_dict, "features": self.features_dict} + entity = {"graph": self.graph_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @@ -1094,8 +1100,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def _load_full_content(session: orm.Session, file_id: str, storage: Storage): - from .model import UploadFile - stmt = sa.select(UploadFile).where(UploadFile.id == file_id) file = session.scalars(stmt).first() assert file is not None, f"UploadFile with id {file_id} should exist but not" @@ -1189,10 +1193,11 @@ class WorkflowNodeExecutionOffload(Base): ) file: Mapped[Optional["UploadFile"]] = orm.relationship( + UploadFile, foreign_keys=[file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id, ) @@ -1689,7 +1694,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1699,7 +1704,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) @@ -1966,10 +1971,11 @@ class WorkflowDraftVariableFile(Base): # Relationship to UploadFile upload_file: Mapped["UploadFile"] = orm.relationship( + UploadFile, foreign_keys=[upload_file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id, ) diff --git a/api/providers/README.md b/api/providers/README.md index a00ec8bc52..5d5e6db9af 100644 --- a/api/providers/README.md +++ b/api/providers/README.md @@ -10,3 +10,6 @@ This directory holds **optional workspace packages** 
that plug into Dify’s API Provider tests often live next to the package, e.g. `providers/<group>/<package>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). +## Excluding Providers + +To build with only selected providers, disable the default bundles with `--no-group vdb-all` and `--no-group trace-all`, then enable the specific providers you need with `--group vdb-<provider>` and `--group trace-<provider>`. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 0000000000..a7ffa5ed26 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. + +## Architecture + +| Layer | Location | Role | +|--------|----------|------| +| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads | +| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class | +| Your package | `api/providers/trace/trace-<provider>/` | Pydantic config + subclass of `BaseTraceInstance` | + +At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype. + +## What you implement + +### 1. Config model (`BaseTracingConfig`) + +Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate. + +Fields fall into two groups used by the manager: + +- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords). +- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints). + +List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct. + +### 2. Trace instance (`BaseTraceInstance`) + +Subclass `BaseTraceInstance` and implement: + +```python +def trace(self, trace_info: BaseTraceInfo) -> None: + ... +``` + +Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including: + +- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace` +- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo` +- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo` + +You may ignore categories your backend does not support; existing providers often no-op unhandled types. + +Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context. + +### 3. Register in the API core + +Upstream changes are required so Dify knows your provider exists: + +1.
+### 3. Register in the API core
+
+Upstream changes are required so Dify knows your provider exists:
+
+1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
+2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
+   - `config_class`: your Pydantic config type
+   - `secret_keys` / `other_keys`: lists of field names as above
+   - `trace_instance`: your `BaseTraceInstance` subclass
+
+   Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`, as in the sketch below.
+
+If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
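+A sketch of such a `match` case (the `mybackend` provider, module paths, and field lists are illustrative; mirror the return shape of the existing cases in `ops_trace_manager.py`):
+
+```python
+# Inside OpsTraceProviderConfigMap.__getitem__; `provider` is the stored string.
+match provider:
+    case TracingProviderEnum.MYBACKEND:
+        # Lazy import so a missing optional install raises a clear ImportError.
+        from dify_trace_mybackend.config import MyBackendConfig
+        from dify_trace_mybackend.mybackend_trace import MyBackendDataTrace
+
+        return {
+            "config_class": MyBackendConfig,
+            "secret_keys": ["api_key"],
+            "other_keys": ["endpoint", "project"],
+            "trace_instance": MyBackendDataTrace,
+        }
+```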
+## Package layout
+
+Each provider is a normal uv workspace member, for example:
+
+- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
+- `api/providers/trace/trace-<name>/src/dify_trace_<name>/` — `config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
+
+Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
+
+## Wiring into the `api` workspace
+
+In `api/pyproject.toml`:
+
+1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
+2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
+
+After changing metadata, run **`uv sync`** from `api/`.
diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml
new file mode 100644
index 0000000000..bcef7e9fb1
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-aliyun"
+version = "0.0.1"
+dependencies = [
+    # versions inherited from parent
+    "opentelemetry-api",
+    "opentelemetry-exporter-otlp-proto-grpc",
+    "opentelemetry-sdk",
+    "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Aliyun)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 70aaf2a07b..54d2f8167f 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -1,12 +1,23 @@ import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -14,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -34,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -46,20 +57,9 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 0000000000..e0133e6cc9 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from
core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f..00aab6bf89 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py 
b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 97% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index aa35ac74c2..5678c66adb 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -2,11 +2,10 @@ import json from collections.abc import Mapping from typing import Any, TypedDict -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -15,8 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d4036..ac09060e9d 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - 
@patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 @@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ 
class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. @@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -226,8 +225,10 @@ class TestSpanBuilder: span = builder.build_span(span_data) assert isinstance(span, ReadableSpan) assert span.name == "test-span" + assert span.context is not None assert span.context.trace_id == 123 assert span.context.span_id == 456 + assert span.parent is not None assert span.parent.span_id == 789 assert span.resource == resource assert span.attributes == {"attr1": "val1"} @@ -268,7 +269,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +291,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 90% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c..a6808fec0a 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from 
dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): @@ -65,12 +64,13 @@ class TestSpanData: def test_span_data_missing_required_fields(self): with pytest.raises(ValidationError): - SpanData( - trace_id=123, - # span_id missing - name="test_span", - start_time=1000, - end_time=2000, + SpanData.model_validate( + { + "trace_id": 123, + "name": "test_span", + "start_time": 1000, + "end_time": 2000, + } ) def test_span_data_arbitrary_types_allowed(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 3961555b9a..9cab40748f 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 90% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index 62d631a754..fa00829653 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -2,16 +2,15 @@ from __future__ import annotations from datetime import UTC, datetime from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -26,7 +25,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -36,6 +36,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) 
+from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -44,7 +46,7 @@ class RecordingTraceClient: self.endpoint = endpoint self.added_spans: list[object] = [] - def add_span(self, span) -> None: + def add_span(self, span: object) -> None: self.added_spans.append(span) def api_check(self) -> bool: @@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: trace_id=trace_id, span_id=span_id, is_remote=False, - trace_flags=TraceFlags.SAMPLED, + trace_flags=TraceFlags(TraceFlags.SAMPLED), ) return Link(context) +def _make_trace_metadata( + trace_id: int = 1, + workflow_span_id: int = 2, + session_id: str = "s", + user_id: str = "u", + links: list[Link] | None = None, +) -> TraceMetadata: + return TraceMetadata( + trace_id=trace_id, + workflow_span_id=workflow_span_id, + session_id=session_id, + user_id=user_id, + links=[] if links is None else links, + ) + + +def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient: + return cast(RecordingTraceClient, trace_instance.trace_client) + + +def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]: + return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans) + + def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: defaults = { "workflow_id": "workflow-id", @@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT trace_instance.workflow_trace(trace_info) add_workflow_span.assert_called_once() - passed_trace_metadata = add_workflow_span.call_args.args[1] + passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1]) assert passed_trace_metadata.trace_id == 111 assert passed_trace_metadata.workflow_span_id == 222 assert passed_trace_metadata.session_id == "c" assert passed_trace_metadata.user_id == "u" assert passed_trace_metadata.links == [] - assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"] def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_message_trace_info(message_data=None) trace_instance.message_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT ) trace_instance.message_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span, llm_span = trace_instance.trace_client.added_spans + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span, llm_span = spans assert message_span.name == "message" assert message_span.trace_id == 10 @@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_dataset_retrieval_trace_info(message_data=None) trace_instance.dataset_retrieval_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: 
pytest.MonkeyPatch): @@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "dataset_retrieval" assert span.attributes[RETRIEVAL_QUERY] == "query" assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' @@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_tool_trace_info(message_data=None) trace_instance.tool_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p ) ) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "my-tool" assert span.status == status assert span.attributes[TOOL_NAME] == "my-tool" @@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) @@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) @@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) @@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) @@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = 
MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) node_execution.node_type = BuiltinNodeTypes.CODE @@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "title" @@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + trace_metadata = _make_trace_metadata(links=[_make_link()]) node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "my-tool" @@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] ) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "retrieval" @@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "llm" @@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() # CASE 1: With message_id trace_info = _make_workflow_trace_info( @@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. ) trace_instance.add_workflow_span(trace_info, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span = trace_instance.trace_client.added_spans[0] - workflow_span = trace_instance.trace_client.added_spans[1] + client = _recording_trace_client(trace_instance) + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span = spans[0] + workflow_span = spans[1] assert message_span.name == "message" assert message_span.span_kind == SpanKind.SERVER @@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. 
assert workflow_span.span_kind == SpanKind.INTERNAL assert workflow_span.parent_span_id == 20 - trace_instance.trace_client.added_spans.clear() + client.added_spans.clear() # CASE 2: Without message_id trace_info_no_msg = _make_workflow_trace_info(message_id=None) trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "workflow" assert span.span_kind == SpanKind.SERVER assert span.parent_span_id is None @@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) trace_instance.suggested_question_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "suggested_question" assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 92% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index 2d2be12f05..1b97746dea 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,11 +1,9 @@ import json +from collections.abc import Mapping +from typing import Any, cast from unittest.mock import MagicMock -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -13,7 +11,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -25,7 +23,11 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -48,7 +50,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +65,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +114,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - 
import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] @@ -170,7 +172,7 @@ def test_create_common_span_attributes(): def test_format_retrieval_documents(): # Not a list - assert format_retrieval_documents("not a list") == [] + assert format_retrieval_documents(cast(list[object], "not a list")) == [] # Valid list docs = [ @@ -211,7 +213,7 @@ def test_format_retrieval_documents(): def test_format_input_messages(): # Not a dict - assert format_input_messages(None) == serialize_json_data([]) + assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No prompts assert format_input_messages({}) == serialize_json_data([]) @@ -244,7 +246,7 @@ def test_format_input_messages(): def test_format_output_messages(): # Not a dict - assert format_output_messages(None) == serialize_json_data([]) + assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No text assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..8068ee1328 --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig.model_validate({}) + + with pytest.raises(ValidationError): + AliyunConfig.model_validate({"license_key": "test_license"}) + + with pytest.raises(ValidationError): + AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"}) + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", 
endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + ) + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_license_key_required(self): + """Test that license_key is required and cannot be empty""" + with pytest.raises(ValidationError): + AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml new file mode 100644 index 0000000000..9e756944c9 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +dependencies = [ + "arize-phoenix-otel~=0.15.0", +] +description = "Dify ops tracing provider (Arize / Phoenix)." 
+ +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py similarity index 100% rename from api/core/ops/langfuse_trace/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py similarity index 99% rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index dd5edde630..96df49ed0e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -6,7 +6,6 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -26,7 +25,6 @@ from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -40,7 +38,9 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py new file mode 100644 index 0000000000..6eac5b30d2 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py @@ -0,0 +1,45 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class ArizeConfig(BaseTracingConfig): + """ + Model class for Arize tracing config. + """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. 
+ """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd7..e9ecc2e083 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,12 +1,9 @@ from datetime import UTC, datetime, timedelta +from typing import cast from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +12,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +81,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -129,7 +130,7 @@ def test_set_span_status(): return "SilentErrorRepr" span.reset_mock() - set_span_status(span, SilentError()) + set_span_status(span, cast(Exception | str | None, SilentError())) assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" @@ -142,8 +143,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") 
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +152,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +163,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +173,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +229,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +263,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +272,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_success(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() @@ -291,7 +292,7 @@ def test_message_trace_success(mock_db, trace_instance): assert trace_instance.tracer.start_span.call_count >= 1 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() diff --git 
a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py index 6b5cb5b09a..a01c63ae61 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes +from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from openinference.semconv.trace import OpenInferenceSpanKindValues -from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..11e951c3b1 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py @@ -0,0 +1,88 @@ +import pytest +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from pydantic import ValidationError + + +class TestArizeConfig: + """Test cases for ArizeConfig""" + + def test_valid_config(self): + """Test valid Arize configuration""" + config = ArizeConfig( + api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" + ) + assert config.api_key == "test_key" + assert config.space_id == "test_space" + assert config.project == "test_project" + assert config.endpoint == "https://custom.arize.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = ArizeConfig() + assert config.api_key is None + assert config.space_id is None + assert config.project is None + assert config.endpoint == "https://otlp.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = ArizeConfig(project="") + assert config.project == "default" + + def test_project_validation_none(self): + """Test project validation with None value""" + config = ArizeConfig(project=None) + assert config.project == "default" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = ArizeConfig(endpoint="") + assert config.endpoint == "https://otlp.arize.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation normalizes URL by removing path""" + config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") + assert config.endpoint == "https://custom.arize.com" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="ftp://invalid.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="invalid.com") + + +class TestPhoenixConfig: + """Test cases for PhoenixConfig""" + + def test_valid_config(self): + """Test valid Phoenix configuration""" + config = PhoenixConfig(api_key="test_key", 
project="test_project", endpoint="https://custom.phoenix.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.phoenix.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = PhoenixConfig() + assert config.api_key is None + assert config.project is None + assert config.endpoint == "https://app.phoenix.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = PhoenixConfig(project="") + assert config.project == "default" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml new file mode 100644 index 0000000000..27d2273a69 --- /dev/null +++ b/api/providers/trace/trace-langfuse/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langfuse" +version = "0.0.1" +dependencies = [ + "langfuse>=4.2.0,<5.0.0", +] +description = "Dify ops tracing provider (Langfuse)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py new file mode 100644 index 0000000000..90d1a2846b --- /dev/null +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py @@ -0,0 +1,19 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class LangfuseConfig(BaseTracingConfig): + """ + Model class for Langfuse tracing config. 
+ """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 99% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index d53aa84aed..68881378a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -3,7 +3,6 @@ import os import uuid from datetime import UTC, datetime, timedelta -from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from langfuse.api import ( CreateGenerationBody, @@ -17,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -29,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -37,9 +38,8 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index 374371fb42..952f10c34f 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ 
b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,9 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -18,14 +25,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() 
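
> Aside: the hunks above and below all make the same mechanical change — every `monkeypatch.setattr`/`patch` target moves from `core.ops.langfuse_trace.langfuse_trace.*` to `dify_trace_langfuse.langfuse_trace.*` — because a mock must target the module that *looks the name up*, not the module that defines it. A minimal self-contained sketch of why (module names below are stand-ins, not the real layout):

```python
# Why every patch target changes with the package move: `from langfuse import
# Langfuse` copies the name into the importing module, so a patch only takes
# effect when it targets that importing module. Stand-in module names below.
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch

sdk = ModuleType("fake_langfuse")
sdk.Langfuse = object  # the "real" client class
sys.modules["fake_langfuse"] = sdk

tracer = ModuleType("fake_trace_langfuse")
tracer.Langfuse = sdk.Langfuse  # result of `from fake_langfuse import Langfuse`
sys.modules["fake_trace_langfuse"] = tracer

with patch("fake_langfuse.Langfuse", MagicMock()):
    assert tracer.Langfuse is object  # patching the SDK misses the tracer's copy

with patch("fake_trace_langfuse.Langfuse", MagicMock()):
    assert tracer.Langfuse is not object  # patching the using module works
```
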
repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..0c3c3fc81e --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = 
LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({}) + + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({"public_key": "public"}) + + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({"secret_key": "secret"}) + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py similarity index 90% rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index f8951d2b4a..82d69b6180 100644 --- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -2,17 +2,18 @@ from datetime import datetime, timedelta from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock, patch -from graphon.enums import BuiltinNodeTypes +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from graphon.enums import BuiltinNodeTypes def _create_trace_instance() -> LangFuseDataTrace: - with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True): + with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True): return LangFuseDataTrace( LangfuseConfig( public_key="public-key", @@ -117,9 +118,9 @@ class TestLangFuseDataTraceCompletionStartTime: patch.object(trace, "add_span"), patch.object(trace, "add_generation") as add_generation, patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), - patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()), + patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()), patch( - "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=repository, ), ): @@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime: assert trace._get_completion_start_time(start_time, None) is None assert trace._get_completion_start_time(start_time, -1) is None - assert trace._get_completion_start_time(start_time, "invalid") is None + assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml new file mode 100644 index 0000000000..8131952b28 --- /dev/null +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ 
-0,0 +1,10 @@ +[project] +name = "dify-trace-langsmith" +version = "0.0.1" +dependencies = [ + "langsmith~=0.7.30", +] +description = "Dify ops tracing provider (LangSmith)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py similarity index 100% rename from api/core/ops/opik_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py new file mode 100644 index 0000000000..498b8c5e7e --- /dev/null +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py @@ -0,0 +1,20 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class LangSmithConfig(BaseTracingConfig): + """ + Model class for Langsmith tracing config. + """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 490c64af84..145bd70dbc 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -4,13 +4,11 @@ import uuid from datetime import datetime, timedelta from typing import cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -22,14 +20,16 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from 
core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index bfe916f018..45e5894e4a 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,9 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -16,12 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + 
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + 
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..bd226c9f1a --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({}) + + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({"api_key": "key"}) + + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({"project": "project"}) + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 0000000000..fad6002944 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 0000000000..84914165e3 --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. 
+ """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. + """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 99% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index c070a937be..4e4c45a532 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow -from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -12,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -25,7 +23,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 96% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index f4c485a9fc..46c9750a5d 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" 
from __future__ import annotations @@ -9,9 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock @@ -599,7 +599,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_instance.message_trace(_make_message_trace_info()) mock_tracing["start"].assert_called_once() @@ -609,7 +608,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(error="something broke") trace_instance.message_trace(trace_info) @@ -620,7 +618,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None monkeypatch.setenv("FILES_URL", "http://files.test") file_data = SimpleNamespace(url="path/to/file.png") @@ -638,7 +635,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(file_list=None, message_file_data=None) trace_instance.message_trace(trace_info) @@ -651,7 +647,6 @@ class TestMessageTrace: end_user = MagicMock() end_user.session_id = "session-xyz" - 
mock_db.session.query.return_value.where.return_value.first.return_value = end_user trace_info = _make_message_trace_info( metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, @@ -664,7 +659,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info( metadata={"from_account_id": "acc-1"}, diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml new file mode 100644 index 0000000000..874997168e --- /dev/null +++ b/api/providers/trace/trace-opik/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-opik" +version = "0.0.1" +dependencies = [ + "opik~=1.11.2", +] +description = "Dify ops tracing provider (Opik)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py new file mode 100644 index 0000000000..c16ff1d903 --- /dev/null +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py @@ -0,0 +1,25 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 99% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index e0c7b9bfe5..2d124ac989 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -5,13 +5,11 @@ import uuid from datetime import datetime, timedelta from typing import Any, cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,7 +22,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git 
a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index 1cb32f2ee0..eefed3c78c 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,9 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_opik.config import OpikConfig +from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -18,7 +18,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -37,7 +37,7 @@ def opik_config(): @pytest.fixture def trace_instance(opik_config, monkeypatch): mock_client = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client) instance = OpikDataTrace(opik_config) return instance @@ -67,7 +67,7 @@ def test_prepare_opik_uuid(): def test_init(opik_config, monkeypatch): mock_opik = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik) monkeypatch.setenv("FILES_URL", "http://test.url") instance = OpikDataTrace(opik_config) @@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) node_llm = MagicMock() node_llm.id = LLM_NODE_ID @@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + 
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..5a54b70bba --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + 
assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def test_url_validation_invalid_scheme(self): + """Test URL validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + OpikConfig(url="ftp://custom.comet.com/opik/api/") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py similarity index 87% rename from api/tests/unit_tests/core/ops/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index ad9d0846be..2e0796c291 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -12,10 +12,12 @@ from __future__ import annotations import uuid from datetime import datetime +from typing import cast from unittest.mock import MagicMock, patch +from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo -from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid # A stable UUID4 used as the workflow_run_id throughout all tests. 
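
> Aside: the test file continuing below pins down the id-derivation behavior (`_seed_to_uuid4`, `prepare_opik_uuid`): the same `(start_time, seed)` pair must always yield the same Opik id, so span ids can be re-derived instead of stored. The real helpers live in `dify_trace_opik.opik_trace` and may hash differently; this sketch only illustrates the stability property being tested:

```python
# Hedged sketch of deterministic id derivation: hash the seed, then coerce
# the digest into a structurally valid UUID4. Illustrative only — not the
# actual _seed_to_uuid4 implementation.
import hashlib
import uuid

def seed_to_uuid4_sketch(seed: str) -> uuid.UUID:
    digest = hashlib.sha256(seed.encode()).digest()[:16]
    return uuid.UUID(bytes=digest, version=4)  # force valid version/variant bits

a = seed_to_uuid4_sketch("workflow-run-1")
b = seed_to_uuid4_sketch("workflow-run-1")
assert a == b  # stable across calls, so spans can re-derive their ids
assert a != seed_to_uuid4_sketch("workflow-run-2")
```
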
_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" @@ -56,8 +58,8 @@ def _make_workflow_trace_info( def _make_opik_trace_instance() -> OpikDataTrace: """Construct an OpikDataTrace with the Opik SDK client mocked out.""" - with patch("core.ops.opik_trace.opik_trace.Opik"): - from core.ops.entities.config_entity import OpikConfig + with patch("dify_trace_opik.opik_trace.Opik"): + from dify_trace_opik.config import OpikConfig config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") instance = OpikDataTrace(config) @@ -68,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace: return instance +def _add_trace_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_trace) + + +def _add_span_mock(instance: OpikDataTrace) -> MagicMock: + return cast(MagicMock, instance.add_span) + + # --------------------------------------------------------------------------- # _seed_to_uuid4 # --------------------------------------------------------------------------- @@ -133,10 +143,10 @@ class TestWorkflowTraceWithoutMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -154,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId: def test_root_span_is_created(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - assert instance.add_span.called + assert _add_span_mock(instance).called def test_root_span_id_matches_expected(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) expected = self._expected_root_span_id(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected def test_root_span_has_no_parent(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["parent_span_id"] is None def test_trace_name_is_workflow_trace(self): @@ -176,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId: trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_name_is_workflow_trace(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE def test_root_span_has_workflow_tag(self): trace_info = _make_workflow_trace_info(message_id=None) instance = self._run(trace_info) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = 
_add_span_mock(instance).call_args_list[0][0][0] assert "workflow" in root_span_kwargs["tags"] def test_node_execution_spans_are_parented_to_root(self): @@ -213,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) # call_args_list[0] = root span, [1] = node execution span - assert instance.add_span.call_count == 2 - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + add_span = _add_span_mock(instance) + assert add_span.call_count == 2 + node_span_kwargs = add_span.call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id def test_node_span_not_parented_to_workflow_app_log_id(self): @@ -239,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId: instance = self._run(trace_info, node_executions=[node_exec]) old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] != old_parent_id def test_root_span_id_differs_from_trace_id(self): @@ -265,10 +276,10 @@ class TestWorkflowTraceWithMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -282,7 +293,7 @@ class TestWorkflowTraceWithMessageId: trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) instance = self._run(trace_info) - trace_kwargs = instance.add_trace.call_args_list[0][0][0] + trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0] assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE def test_root_span_uses_workflow_run_id_directly(self): @@ -291,7 +302,7 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info) expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) - root_span_kwargs = instance.add_span.call_args_list[0][0][0] + root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0] assert root_span_kwargs["id"] == expected_root_span_id def test_root_span_id_differs_from_no_message_id_case(self): @@ -325,5 +336,5 @@ class TestWorkflowTraceWithMessageId: instance = self._run(trace_info, node_executions=[node_exec]) - node_span_kwargs = instance.add_span.call_args_list[1][0][0] + node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0] assert node_span_kwargs["parent_span_id"] == expected_root_span_id diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml new file mode 100644 index 0000000000..eab06fc708 --- /dev/null +++ b/api/providers/trace/trace-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-tencent" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Tencent APM)." 
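
> Aside, on the Opik test rewrite just above: `instance.add_span = MagicMock()` is invisible to the type checker, which still sees the declared method, so attribute access like `.call_args_list` fails type-checking. The new `_add_span_mock`/`_add_trace_mock` helpers solve this with a single `cast`; a minimal illustration with generic names, not the real classes:

```python
# Typed access to a method replaced by MagicMock at runtime: without the
# cast, checkers reject `.call_count` on what they believe is a bound method.
from typing import cast
from unittest.mock import MagicMock

class Tracer:
    def add_span(self, data: dict) -> None: ...

def _add_span_mock(t: Tracer) -> MagicMock:
    return cast(MagicMock, t.add_span)

t = Tracer()
t.add_span = MagicMock()  # type: ignore[method-assign]  # done in the tests' fixtures
_add_span_mock(t)({"id": "span-1"})
assert _add_span_mock(t).call_count == 1
```
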
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
similarity index 100%
rename from api/core/ops/tencent_trace/client.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
new file mode 100644
index 0000000000..398e6c55a8
--- /dev/null
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
@@ -0,0 +1,30 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+
+
+class TencentConfig(BaseTracingConfig):
+    """
+    Tencent APM tracing config
+    """
+
+    token: str
+    endpoint: str
+    service_name: str
+
+    @field_validator("token")
+    @classmethod
+    def token_validator(cls, v, info: ValidationInfo):
+        if not v or v.strip() == "":
+            raise ValueError("Token cannot be empty")
+        return v
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
+
+    @field_validator("service_name")
+    @classmethod
+    def service_name_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "dify_app")
diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/semconv.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
similarity index 98%
rename from api/core/ops/tencent_trace/span_builder.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
index f79095d966..763a85ffd7 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
@@ -6,8 +6,6 @@
 import json
 import logging
 from datetime import datetime
 
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from opentelemetry.trace import Status, StatusCode
 
 from core.ops.entities.trace_entity import (
@@ -16,7 +14,8 @@ from core.ops.entities.trace_entity import (
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.ops.tencent_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_tencent.entities.semconv import (
     GEN_AI_COMPLETION,
     GEN_AI_FRAMEWORK,
     GEN_AI_IS_ENTRY,
@@ -40,9 +39,10 @@ from core.ops.tencent_trace.entities.semconv import (
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.utils import TencentTraceUtils
-from core.rag.models.document import Document
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.utils import TencentTraceUtils
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
 logger = logging.getLogger(__name__)
 
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
similarity index 94%
rename from api/core/ops/tencent_trace/tencent_trace.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index 84f54d8a5a..a8c480e4a5 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,18 +1,12 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
 
+import inspect
 import logging
 
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-)
-from graphon.nodes import BuiltinNodeTypes
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import TencentConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -23,12 +17,17 @@ from core.ops.entities.trace_entity import (
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.ops.tencent_trace.client import TencentTraceClient
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
-from core.ops.tencent_trace.utils import TencentTraceUtils
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from dify_trace_tencent.client import TencentTraceClient
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from dify_trace_tencent.utils import TencentTraceUtils
 from extensions.ext_database import db
+from graphon.entities.workflow_node_execution import (
+    WorkflowNodeExecution,
+)
+from graphon.nodes import BuiltinNodeTypes
 from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
     """
     Tencent APM trace implementation with single responsibility principle.
     Acts as a coordinator that delegates specific tasks to specialized classes.
+
+    The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+    explicitly in tests or implicitly during garbage collection, so shutdown
+    must be safe to call multiple times.
     """
 
+    trace_client: TencentTraceClient
+    _closed: bool
+
     def __init__(self, tencent_config: TencentConfig):
         super().__init__(tencent_config)
+        self._closed = False
         self.trace_client = TencentTraceClient(
             service_name=tencent_config.service_name,
             endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
         except Exception:
             logger.debug("[Tencent APM] Failed to record message trace duration")
 
-    def __del__(self):
-        """Ensure proper cleanup on garbage collection."""
+    def close(self) -> None:
+        """Synchronously and idempotently shut down the underlying trace client."""
+        if getattr(self, "_closed", False):
+            return
+
+        self._closed = True
+        trace_client = getattr(self, "trace_client", None)
+        if trace_client is None:
+            return
+
         try:
-            if hasattr(self, "trace_client"):
-                self.trace_client.shutdown()
+            shutdown_result = trace_client.shutdown()
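+            # In tests shutdown() may be replaced by an AsyncMock, whose return
+            # value is an unawaited coroutine. Close it explicitly so garbage
+            # collection does not emit a "coroutine was never awaited" warning.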
""" + trace_client: TencentTraceClient + _closed: bool + def __init__(self, tencent_config: TencentConfig): super().__init__(tencent_config) + self._closed = False self.trace_client = TencentTraceClient( service_name=tencent_config.service_name, endpoint=tencent_config.endpoint, @@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance): except Exception: logger.debug("[Tencent APM] Failed to record message trace duration") - def __del__(self): - """Ensure proper cleanup on garbage collection.""" + def close(self) -> None: + """Synchronously and idempotently shutdown the underlying trace client.""" + if getattr(self, "_closed", False): + return + + self._closed = True + trace_client = getattr(self, "trace_client", None) + if trace_client is None: + return + try: - if hasattr(self, "trace_client"): - self.trace_client.shutdown() + shutdown_result = trace_client.shutdown() + if inspect.isawaitable(shutdown_result): + close_awaitable = getattr(shutdown_result, "close", None) + if callable(close_awaitable): + close_awaitable() except Exception: logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") + + def __del__(self): + """Ensure best-effort cleanup on garbage collection without retrying shutdown.""" + self.close() diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py similarity index 100% rename from api/core/ops/tencent_trace/utils.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py similarity index 85% rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 870c18e53e..3cd918f408 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -5,15 +5,15 @@ from __future__ import annotations import sys import types from types import SimpleNamespace +from typing import Any, TypedDict, cast from unittest.mock import MagicMock import pytest +from dify_trace_tencent import client as client_module +from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version +from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event -from opentelemetry.trace import Status, StatusCode - -from core.ops.tencent_trace import client as client_module -from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData +from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] @@ -81,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality: self.kwargs = kwargs +class PatchedCoreComponents(TypedDict): + span_exporter: MagicMock + span_processor: MagicMock + tracer: MagicMock + span: MagicMock + tracer_provider: MagicMock + logger: MagicMock + trace_api: Any + + def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: """Drop fake metric modules into sys.modules so the client imports resolve.""" @@ -119,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture(autouse=True) -def 
patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents: span_exporter = MagicMock(name="span_exporter") monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) @@ -169,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: } +def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext: + return SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + + def _build_client() -> TencentTraceClient: return TencentTraceClient( service_name="service", @@ -209,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st def test_resolve_grpc_target_handles_errors() -> None: - assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317) @pytest.mark.parametrize( @@ -249,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None: client.record_trace_duration(0.5) -def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: +def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() client.hist_llm_duration = MagicMock(name="hist_llm_duration") client.hist_llm_duration.record.side_effect = RuntimeError("boom") @@ -259,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, logger.debug.assert_called() -def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() span = patch_core_components["span"] - span.get_span_context.return_value = "ctx" + ctx = _make_span_context(span_id=2) + span.get_span_context.return_value = ctx data = SpanData( trace_id=1, @@ -281,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, span.add_event.assert_called_once() span.set_status.assert_called_once() span.end.assert_called_once_with(end_time=20) - assert client.span_contexts[2] == "ctx" + assert client.span_contexts[2] == ctx -def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: +def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None: client = _build_client() - client.span_contexts[10] = "existing" + existing_context = _make_span_context(span_id=10) + client.span_contexts[10] = existing_context span = patch_core_components["span"] - span.get_span_context.return_value = "child" + span.get_span_context.return_value = _make_span_context(span_id=11) data = SpanData( trace_id=1, @@ -303,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[ client._create_and_export_span(data) trace_api = patch_core_components["trace_api"] - trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.NonRecordingSpan.assert_called_once_with(existing_context) trace_api.set_span_in_context.assert_called_once() -def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: +def 
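+# Helper: build a minimal sampled SpanContext; the client stores one per span id.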
+def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
+    return SpanContext(
+        trace_id=trace_id,
+        span_id=span_id,
+        is_remote=False,
+        trace_flags=TraceFlags(TraceFlags.SAMPLED),
+    )
+
+
 def _build_client() -> TencentTraceClient:
     return TencentTraceClient(
         service_name="service",
@@ -209,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
 
 
 def test_resolve_grpc_target_handles_errors() -> None:
-    assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
+    assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
 
 
 @pytest.mark.parametrize(
@@ -249,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
     client.record_trace_duration(0.5)
 
 
-def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
+def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     client.hist_llm_duration = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@@ -259,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
     logger.debug.assert_called()
 
 
-def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    ctx = _make_span_context(span_id=2)
+    span.get_span_context.return_value = ctx
 
     data = SpanData(
         trace_id=1,
@@ -281,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
     span.add_event.assert_called_once()
     span.set_status.assert_called_once()
     span.end.assert_called_once_with(end_time=20)
-    assert client.span_contexts[2] == "ctx"
+    assert client.span_contexts[2] == ctx
 
 
-def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
-    client.span_contexts[10] = "existing"
+    existing_context = _make_span_context(span_id=10)
+    client.span_contexts[10] = existing_context
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "child"
+    span.get_span_context.return_value = _make_span_context(span_id=11)
 
     data = SpanData(
         trace_id=1,
@@ -303,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
     client._create_and_export_span(data)
 
     trace_api = patch_core_components["trace_api"]
-    trace_api.NonRecordingSpan.assert_called_once_with("existing")
+    trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
     trace_api.set_span_in_context.assert_called_once()
 
 
-def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
    client = _build_client()
    span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)
 
     client.tracer.start_span.side_effect = RuntimeError("boom")
     client._create_and_export_span(
@@ -386,7 +407,7 @@ def test_get_project_url() -> None:
     assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
 
 
-def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span_processor = patch_core_components["span_processor"]
     tracer_provider = patch_core_components["tracer_provider"]
@@ -402,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
     metric_reader.shutdown.assert_called_once()
 
 
-def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     meter_provider = meter_provider_instances[-1]
     meter_provider.shutdown.side_effect = RuntimeError("boom")
+    assert client.metric_reader is not None
     client.metric_reader.shutdown.side_effect = RuntimeError("boom")
 
     client.shutdown()
@@ -434,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
     assert client.metric_reader is None
 
 
-def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
+def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@@ -455,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
     logger.exception.assert_called_once()
 
 
-def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)
 
     data = SpanData.model_construct(
         trace_id=1,
@@ -486,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration = hist_mock
 
-    client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
+    client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["foo"], str)
     assert attrs["bar"] == 2
@@ -497,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_trace_duration")
     client.hist_trace_duration = hist_mock
 
-    client.record_trace_duration(1.0, {"meta": object(), "ok": True})
+    client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["meta"], str)
     assert attrs["ok"] is True
@@ -513,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     ],
 )
 def test_record_methods_handle_exceptions(
-    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
+    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
 ) -> None:
     client = _build_client()
     hist_mock = MagicMock(name=attr_name)
@@ -528,35 +550,38 @@ def test_record_methods_handle_exceptions(
 
 def test_metrics_initializes_grpc_metric_exporter() -> None:
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
+    assert isinstance(exporter, DummyGrpcMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
-    assert metric_reader.exporter.kwargs["insecure"] is False
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
+    assert exporter.kwargs["insecure"] is False
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
 
 
 def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
+    assert isinstance(exporter, DummyHttpMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
 
 
 def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
+    assert isinstance(exporter, DummyJsonMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
-    assert "preferred_temporality" in metric_reader.exporter.kwargs
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
+    assert "preferred_temporality" in exporter.kwargs
 
 
 def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -565,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
     monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
     _ = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
-    assert "preferred_temporality" not in metric_reader.exporter.kwargs
+    assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
+    assert "preferred_temporality" not in exporter.kwargs
 
 
 def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
similarity index 89%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
index 696f859b6f..e850a801f3 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
@@ -1,17 +1,7 @@
 from datetime import datetime
 from unittest.mock import MagicMock, patch
 
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from opentelemetry.trace import StatusCode
-
-from core.ops.entities.trace_entity import (
-    DatasetRetrievalTraceInfo,
-    MessageTraceInfo,
-    ToolTraceInfo,
-    WorkflowTraceInfo,
-)
-from core.ops.tencent_trace.entities.semconv import (
+from dify_trace_tencent.entities.semconv import (
     GEN_AI_IS_ENTRY,
     GEN_AI_IS_STREAMING_REQUEST,
     GEN_AI_MODEL_NAME,
@@ -25,13 +15,23 @@ from core.ops.tencent_trace.entities.semconv import (
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from opentelemetry.trace import StatusCode
+
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    MessageTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
 from core.rag.models.document import Document
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
 
 class TestTencentSpanBuilder:
     def test_get_time_nanoseconds(self):
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
             mock_convert.return_value = 123456789
             dt = datetime.now()
             result = TencentSpanBuilder._get_time_nanoseconds(dt)
@@ -48,7 +48,7 @@ class TestTencentSpanBuilder:
         trace_info.workflow_run_outputs = {"answer": "world"}
         trace_info.metadata = {"conversation_id": "conv_id"}
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.side_effect = [1, 2]  # workflow_span_id, message_span_id
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -70,7 +70,7 @@ class TestTencentSpanBuilder:
         trace_info.workflow_run_outputs = {}
         trace_info.metadata = {}  # No conversation_id
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 1
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -98,7 +98,7 @@ class TestTencentSpanBuilder:
         }
         node_execution.outputs = {"text": "world"}
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 456
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -123,7 +123,7 @@ class TestTencentSpanBuilder:
             "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
         }
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 456
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -142,7 +142,7 @@ class TestTencentSpanBuilder:
         trace_info.metadata = {"conversation_id": "conv_id"}
         trace_info.is_streaming_request = True
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 789
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -162,7 +162,7 @@ class TestTencentSpanBuilder:
         trace_info.metadata = {}
         trace_info.is_streaming_request = False
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 789
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -182,7 +182,7 @@ class TestTencentSpanBuilder:
         trace_info.tool_inputs = {"i": 2}
         trace_info.tool_outputs = "result"
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 101
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
@@ -204,7 +204,7 @@ class TestTencentSpanBuilder:
         )
         trace_info.documents = [doc]
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 202
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -222,7 +222,7 @@ class TestTencentSpanBuilder:
         trace_info.end_time = datetime.now()
         trace_info.documents = []
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 202
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -264,7 +264,7 @@ class TestTencentSpanBuilder:
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 303
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -286,7 +286,7 @@ class TestTencentSpanBuilder:
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 303
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -307,7 +307,7 @@ class TestTencentSpanBuilder:
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 404
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -329,7 +329,7 @@ class TestTencentSpanBuilder:
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 404
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -350,7 +350,7 @@ class TestTencentSpanBuilder:
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 505
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
similarity index 86%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
index f67abba807..54524b09ca 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
@@ -1,11 +1,12 @@
+import gc
 import logging
-from unittest.mock import MagicMock, patch
+import warnings
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import BuiltinNodeTypes
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.tencent_trace import TencentDataTrace
 
-from core.ops.entities.config_entity import TencentConfig
 from core.ops.entities.trace_entity import (
     DatasetRetrievalTraceInfo,
     GenerateNameTraceInfo,
@@ -15,7 +16,8 @@ from core.ops.entities.trace_entity import (
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.ops.tencent_trace.tencent_trace import TencentDataTrace
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import BuiltinNodeTypes
 from models import Account, App, TenantAccountJoin
 
 logger = logging.getLogger(__name__)
@@ -28,19 +30,19 @@ def tencent_config():
 
 @pytest.fixture
 def mock_trace_client():
-    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
+    with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock:
         yield mock
 
 
 @pytest.fixture
 def mock_span_builder():
-    with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
+    with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock:
         yield mock
 
 
 @pytest.fixture
 def mock_trace_utils():
-    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
+    with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock:
         yield mock
 
 
@@ -198,9 +200,9 @@ class TestTencentDataTrace:
         trace_info.workflow_run_id = "run-id"
 
         with patch(
-            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+            "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
         ):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 tencent_data_trace.workflow_trace(trace_info)
                 mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
@@ -230,9 +232,9 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=MessageTraceInfo)
 
         with patch(
-            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+            "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
        ):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 tencent_data_trace.message_trace(trace_info)
                 mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
@@ -262,9 +264,9 @@ class TestTencentDataTrace:
         trace_info.message_id = "msg-id"
 
         with patch(
-            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+            "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
         ):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 tencent_data_trace.tool_trace(trace_info)
                 mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
@@ -294,22 +296,22 @@ class TestTencentDataTrace:
         trace_info.message_id = "msg-id"
 
         with patch(
-            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+            "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
         ):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 tencent_data_trace.dataset_retrieval_trace(trace_info)
                 mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
 
     def test_suggested_question_trace(self, tencent_data_trace):
         trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
             tencent_data_trace.suggested_question_trace(trace_info)
             mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
 
     def test_suggested_question_trace_exception(self, tencent_data_trace):
         trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 tencent_data_trace.suggested_question_trace(trace_info)
                 mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
@@ -342,7 +344,7 @@ class TestTencentDataTrace:
 
         with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
             with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
-                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                     tencent_data_trace._process_workflow_nodes(trace_info, 123)
                     # The exception should be caught by the outer handler since convert_to_span_id is called first
                     mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -351,7 +353,7 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=WorkflowTraceInfo)
         mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
             tencent_data_trace._process_workflow_nodes(trace_info, 123)
             mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -381,7 +383,7 @@ class TestTencentDataTrace:
         node.id = "n1"
         mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
             result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
             assert result is None
             mock_log.assert_called_once()
@@ -403,15 +405,13 @@ class TestTencentDataTrace:
 
         mock_executions = [MagicMock()]
 
-        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+        with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
             mock_db.engine = "engine"
-            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+            with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
                 session = mock_session_ctx.return_value.__enter__.return_value
                 session.scalar.side_effect = [app, account, tenant_join]
-                with patch(
-                    "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
-                ) as mock_repo:
+                with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo:
                     mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions
                     results = tencent_data_trace._get_workflow_node_executions(trace_info)
@@ -423,7 +423,7 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=WorkflowTraceInfo)
         trace_info.metadata = {}
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
             results = tencent_data_trace._get_workflow_node_executions(trace_info)
             assert results == []
             mock_log.assert_called_once()
@@ -432,14 +432,14 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=WorkflowTraceInfo)
         trace_info.metadata = {"app_id": "app-1"}
 
-        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+        with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
             mock_db.init_app = MagicMock()  # Ensure init_app is mocked
             mock_db.engine = "engine"
-            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+            with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
                 session = mock_session_ctx.return_value.__enter__.return_value
                 session.scalar.return_value = None
-                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                     results = tencent_data_trace._get_workflow_node_executions(trace_info)
                     assert results == []
                     mock_log.assert_called_once()
@@ -449,8 +449,8 @@ class TestTencentDataTrace:
         trace_info.tenant_id = "tenant-1"
         trace_info.metadata = {"user_id": "user-1"}
 
-        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
-            with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+        with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
+            with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
                 mock_db.init_app = MagicMock()
                 mock_db.engine = MagicMock()
@@ -476,8 +476,8 @@ class TestTencentDataTrace:
         trace_info.tenant_id = "t"
         trace_info.metadata = {"user_id": "u"}
 
-        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
-            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
+            with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
                 user_id = tencent_data_trace._get_user_id(trace_info)
                 assert user_id == "unknown"
                 mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
@@ -519,7 +519,7 @@ class TestTencentDataTrace:
         node.process_data = None
         node.outputs = None
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
             tencent_data_trace._record_llm_metrics(node)  # Should not crash
@@ -557,7 +557,7 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=MessageTraceInfo)
         trace_info.metadata = None
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
             tencent_data_trace._record_message_llm_metrics(trace_info)  # Should not crash
@@ -609,7 +609,7 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=WorkflowTraceInfo)
         trace_info.start_time = MagicMock()  # This might cause total_seconds() to fail if not mocked right
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
             tencent_data_trace._record_workflow_trace_duration(trace_info)
 
     def test_record_message_trace_duration(self, tencent_data_trace):
@@ -631,16 +631,41 @@ class TestTencentDataTrace:
         trace_info = MagicMock(spec=MessageTraceInfo)
         trace_info.start_time = None
 
-        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+        with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
             tencent_data_trace._record_message_trace_duration(trace_info)
 
-    def test_del(self, tencent_data_trace):
+    def test_close(self, tencent_data_trace):
         client = tencent_data_trace.trace_client
-        tencent_data_trace.__del__()
+        tencent_data_trace.close()
         client.shutdown.assert_called_once()
 
-    def test_del_exception(self, tencent_data_trace):
+    def test_close_is_idempotent(self, tencent_data_trace):
+        client = tencent_data_trace.trace_client
+
+        tencent_data_trace.close()
+        tencent_data_trace.close()
+
+        client.shutdown.assert_called_once()
+
+    def test_close_exception(self, tencent_data_trace):
         tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
-        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
-            tencent_data_trace.__del__()
+        with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
+            tencent_data_trace.close()
             mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+    def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
+        shutdown = AsyncMock()
+        tencent_data_trace.trace_client.shutdown = shutdown
+
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+            tencent_data_trace.close()
+            gc.collect()
+
+        shutdown.assert_called_once()
+        assert not [
+            warning
+            for warning in caught
+            if issubclass(warning.category, RuntimeWarning)
+            and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+        ]
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
similarity index 88%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
index ef28d18e20..63c6d680d7 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
@@ -8,10 +8,9 @@ from datetime import UTC, datetime
 from unittest.mock import patch
 
 import pytest
+from dify_trace_tencent.utils import TencentTraceUtils
 from opentelemetry.trace import Link, TraceFlags
 
-from core.ops.tencent_trace.utils import TencentTraceUtils
-
 
 def test_convert_to_trace_id_with_valid_uuid() -> None:
     uuid_str = "12345678-1234-5678-1234-567812345678"
@@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None:
 
 def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
     expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
-    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+    with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
         assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
         uuid4_mock.assert_called_once()
@@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
 
 def test_convert_to_span_id_uses_uuid4_when_none() -> None:
     expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
-    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+    with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
         span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
         assert isinstance(span_id, int)
         uuid4_mock.assert_called_once()
@@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
 
 def test_generate_span_id_skips_invalid_span_id() -> None:
     with patch(
-        "core.ops.tencent_trace.utils.random.getrandbits",
+        "dify_trace_tencent.utils.random.getrandbits",
         side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
     ) as bits_mock:
         assert TencentTraceUtils.generate_span_id() == 42
@@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
     fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
     expected = int(fixed.timestamp() * 1e9)
 
-    with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
+    with patch("dify_trace_tencent.utils.datetime") as datetime_mock:
         datetime_mock.now.return_value = fixed
         assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
         datetime_mock.now.assert_called_once()
@@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i
 
 @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
 def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
     fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
-    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
+    with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
         link = TencentTraceUtils.create_link(trace_id_str)  # type: ignore[arg-type]
         assert link.context.trace_id == fallback_uuid.int
         uuid4_mock.assert_called_once()
diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml
new file mode 100644
index 0000000000..ba449f2a93
--- /dev/null
+++ b/api/providers/trace/trace-weave/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-weave"
+version = "0.0.1"
+dependencies = [
+    "weave>=0.52.36",
+]
+description = "Dify ops tracing provider (Weave)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
new file mode 100644
index 0000000000..5942bd57fe
--- /dev/null
+++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
@@ -0,0 +1,29 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class WeaveConfig(BaseTracingConfig):
+    """
+    Model class for Weave tracing config.
+    """
+
+    api_key: str
+    entity: str | None = None
+    project: str
+    endpoint: str = "https://trace.wandb.ai"
+    host: str | None = None
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        # Weave only allows HTTPS for endpoint
+        return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
+
+ """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 8d9ba4694d..4292cbf0f1 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -6,7 +6,6 @@ from typing import Any, cast import wandb import weave -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -18,7 +17,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -30,9 +28,11 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..377c768198 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + 
host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig.model_validate({}) + + with pytest.raises(ValidationError): + WeaveConfig.model_validate({"api_key": "key"}) + + with pytest.raises(ValidationError): + WeaveConfig.model_validate({"project": "project"}) + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 5014f40afc..6028d0c550 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,10 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,8 +22,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 
 
 
 @pytest.fixture
 def mock_wandb():
-    with patch("core.ops.weave_trace.weave_trace.wandb") as mock:
+    with patch("dify_trace_weave.weave_trace.wandb") as mock:
         mock.login.return_value = True
         yield mock
 
 
 @pytest.fixture
 def mock_weave():
-    with patch("core.ops.weave_trace.weave_trace.weave") as mock:
+    with patch("dify_trace_weave.weave_trace.weave") as mock:
         client = MagicMock()
         client.entity = "my-entity"
         client.project = "my-project"
@@ -307,7 +307,7 @@ class TestGetProjectUrl:
         monkeypatch.setattr(trace_instance, "entity", None)
         monkeypatch.setattr(trace_instance, "project_name", None)
         # Force an error by making string formatting fail
-        with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger:
+        with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
             # Simulate exception via property
             original_entity = trace_instance.entity
             trace_instance.entity = None
@@ -594,9 +594,9 @@ class TestWorkflowTrace:
         mock_factory = MagicMock()
         mock_factory.create_workflow_node_execution_repository.return_value = repo
 
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory)
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+        monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
         return repo
 
     def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
@@ -703,8 +703,8 @@ class TestWorkflowTrace:
 
     def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
         """Raises ValueError when app_id is missing from metadata."""
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+        monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
 
         trace_info = _make_workflow_trace_info(
             message_id=None,
@@ -802,7 +802,7 @@ class TestMessageTrace:
     def test_basic_message_trace(self, trace_instance, monkeypatch):
         """message_trace creates message run and llm child run."""
         monkeypatch.setattr(
-            "core.ops.weave_trace.weave_trace.db.session.get",
+            "dify_trace_weave.weave_trace.db.session.get",
             lambda model, pk: None,
         )
 
@@ -824,7 +824,7 @@ class TestMessageTrace:
         mock_db = MagicMock()
         mock_db.session.get.return_value = None
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
 
         trace_instance.start_call = MagicMock()
         trace_instance.finish_call = MagicMock()
@@ -846,7 +846,7 @@ class TestMessageTrace:
         mock_db = MagicMock()
         mock_db.session.get.return_value = end_user
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
 
         trace_instance.start_call = MagicMock()
         trace_instance.finish_call = MagicMock()
@@ -866,7 +866,7 @@ class TestMessageTrace:
         """message_trace handles when from_end_user_id is None."""
         mock_db = MagicMock()
         mock_db.session.get.return_value = None
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
 
         trace_instance.start_call = MagicMock()
         trace_instance.finish_call = MagicMock()
@@ -884,7 +884,7 @@ class TestMessageTrace:
         """trace_id falls back to message_id when trace_id is None."""
         mock_db = MagicMock()
         mock_db.session.get.return_value = None
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
 
         trace_instance.start_call = MagicMock()
         trace_instance.finish_call = MagicMock()
@@ -899,7 +899,7 @@ class TestMessageTrace:
         """message_trace handles file_list=None gracefully."""
         mock_db = MagicMock()
         mock_db.session.get.return_value = None
-        monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+        monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
 
         trace_instance.start_call = MagicMock()
         trace_instance.finish_call = MagicMock()
diff --git a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
index b2908ebdae..11398efb58 100644
--- a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
+++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
@@ -1,6 +1,6 @@
 import json
 import uuid
-from collections.abc import Iterator
+from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Any
 
@@ -75,7 +75,7 @@ class AnalyticdbVectorBySql:
         )
 
     @contextmanager
-    def _get_cursor(self) -> Iterator[Any]:
+    def _get_cursor(self) -> Generator[Any, None, None]:
         assert self.pool is not None, "Connection pool is not initialized"
         conn = self.pool.getconn()
         cur = conn.cursor()
diff --git a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py
index 815ac30c0b..bab176e285 100644
--- a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py
+++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py
@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
         auth = PasswordAuthenticator(config.user, config.password)
         options = ClusterOptions(auth)
-        self._cluster = Cluster(config.connection_string, options)
+        self._cluster = Cluster(config.connection_string, options)  # pyright: ignore[reportArgumentType]
         self._bucket = self._cluster.bucket(config.bucket_name)
         self._scope = self._bucket.scope(config.scope_name)
         self._bucket_name = config.bucket_name
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # pyright: ignore[reportCallIssue]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )
diff --git a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py
index 87b9d813ec..e2f390402a 100644
--- a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py
+++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py
@@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
         self,
         embeddings: list[list[float]],
         metadatas: list[dict[Any, Any]] | None = None,
-        index_params: dict | None = None,
+        index_params: dict[str, Any] | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py
index 46f3224a95..823b877707 100644
--- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py
+++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast
 
 from packaging import version
 from pydantic import BaseModel, model_validator
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
     def _load_collection_fields(self, fields: list[str] | None = None):
         if fields is None:
             # Load collection fields from remote server
-            collection_info = self._client.describe_collection(self._collection_name)
+            collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
             fields = [field["name"] for field in collection_info["fields"]]
         # Since primary field is auto-id, no need to track it
         self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
             return False
 
         try:
-            milvus_version = self._client.get_server_version()
+            milvus_version_raw = self._client.get_server_version()
+            milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
             # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
             if "Zilliz Cloud" in milvus_version:
                 return True
diff --git a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py
index 70377c82c8..5d9ab38529 100644
--- a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py
+++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py
@@ -3,7 +3,7 @@
 import json
 import logging
 import re
 import uuid
-from typing import Any
+from typing import Any, TypedDict
 
 import jieba.posseg as pseg  # type: ignore
 import numpy
@@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
 
 oracledb.defaults.fetch_lobs = False
 
 
ElasticSearchJaVector(ElasticSearchVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 46f3224a95..823b877707 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from packaging import version from pydantic import BaseModel, model_validator @@ -92,7 +92,7 @@ class MilvusVector(BaseVector): def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server - collection_info = self._client.describe_collection(self._collection_name) + collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name)) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it self._fields = [f for f in fields if f != Field.PRIMARY_KEY] @@ -106,7 +106,8 @@ class MilvusVector(BaseVector): return False try: - milvus_version = self._client.get_server_version() + milvus_version_raw = self._client.get_server_version() + milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw) # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility if "Zilliz Cloud" in milvus_version: return True diff --git a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index 70377c82c8..5d9ab38529 100644 --- a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -3,7 +3,7 @@ import json import logging import re import uuid -from typing import Any +from typing import Any, TypedDict import jieba.posseg as pseg # type: ignore import numpy @@ -25,6 +25,18 @@ logger = logging.getLogger(__name__) oracledb.defaults.fetch_lobs = False +class _OraclePoolParams(TypedDict, total=False): + user: str + password: str + dsn: str + min: int + max: int + increment: int + config_dir: str | None + wallet_location: str | None + wallet_password: str | None + + class OracleVectorConfig(BaseModel): user: str password: str @@ -127,22 +139,18 @@ class OracleVector(BaseVector): return connection def _create_connection_pool(self, config: OracleVectorConfig): - pool_params = { - "user": config.user, - "password": config.password, - "dsn": config.dsn, - "min": 1, - "max": 5, - "increment": 1, - } + pool_params = _OraclePoolParams( + user=config.user, + password=config.password, + dsn=config.dsn, + min=1, + max=5, + increment=1, + ) if config.is_autonomous: - pool_params.update( - { - "config_dir": config.config_dir, - "wallet_location": config.wallet_location, - "wallet_password": config.wallet_password, - } - ) + pool_params["config_dir"] = config.config_dir + pool_params["wallet_location"] = config.wallet_location + pool_params["wallet_password"] = config.wallet_password return oracledb.create_pool(**pool_params) def create(self, texts: list[Document], embeddings: list[list[float]], 
**kwargs): diff --git a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py index c335bc610d..e087ec30a5 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from numpy import ndarray @@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase): __tablename__: str id: Mapped[UUID] text: Mapped[str] - meta: Mapped[dict] + meta: Mapped[dict[str, Any]] vector: Mapped[ndarray] diff --git a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py index 84b5eba0ac..9c721c8bde 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): primary_key=True, ) text: Mapped[str] - meta: Mapped[dict] = mapped_column(postgresql.JSONB) + meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) self._table = _Table diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py index bb8a580ebf..abca55f540 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -1,4 +1,5 @@ import json +import logging import os import uuid from collections.abc import Generator, Iterable, Sequence @@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any import httpx import qdrant_client + +logger = logging.getLogger(__name__) from flask import current_app from httpx import DigestAuth from pydantic import BaseModel @@ -421,13 +424,16 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id) stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: + logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id) with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: + logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: @@ -437,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): .limit(1) ) if idle_tidb_auth_binding: + logger.info( + "Assigning idle cluster %s to tenant %s", + idle_tidb_auth_binding.cluster_id, + dataset.tenant_id, + ) idle_tidb_auth_binding.active = True idle_tidb_auth_binding.tenant_id = dataset.tenant_id db.session.commit() + tidb_auth_binding = idle_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" else: + logger.info("No idle clusters available, 
creating new cluster for tenant %s", dataset.tenant_id) new_cluster = TidbService.create_tidb_serverless_cluster( dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_API_URL or "", @@ -450,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): dify_config.TIDB_PRIVATE_KEY or "", dify_config.TIDB_REGION or "", ) + logger.info( + "New cluster created: cluster_id=%s, qdrant_endpoint=%s", + new_cluster["cluster_id"], + new_cluster.get("qdrant_endpoint"), + ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), tenant_id=dataset.tenant_id, active=True, status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() + tidb_auth_binding = new_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" else: + logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + qdrant_url = ( + (tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or "" + ) + logger.info( + "Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)", + qdrant_url, + tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None, + dify_config.TIDB_ON_QDRANT_URL, + ) + if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix @@ -479,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL or "", + endpoint=qdrant_url, api_key=TIDB_ON_QDRANT_API_KEY, root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index 37114be6e7..ece061db67 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -1,3 +1,4 @@ +import logging import time import uuid from collections.abc import Sequence @@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +logger = logging.getLogger(__name__) + # Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn _tidb_http_client: httpx.Client = get_pooled_http_client( "tidb:cloud", @@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client( class TidbService: + @staticmethod + def extract_qdrant_endpoint(cluster_response: dict) -> str | None: + """Extract the qdrant endpoint URL from a Get Cluster API response. + + Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``), + prepends ``qdrant-`` and wraps it as an ``https://`` URL. 
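+
+        Illustrative example (host value borrowed from the note above)::
+
+            >>> TidbService.extract_qdrant_endpoint(
+            ...     {"endpoints": {"public": {"host": "gateway01.xx.tidbcloud.com"}}}
+            ... )
+            'https://qdrant-gateway01.xx.tidbcloud.com'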
+ """ + endpoints = cluster_response.get("endpoints") or {} + public = endpoints.get("public") or {} + host = public.get("host") + if host: + return f"https://qdrant-{host}" + return None + + @staticmethod + def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: + """Call Get Cluster API and extract the qdrant endpoint. + + Use ``extract_qdrant_endpoint`` instead when you already have + the cluster response to avoid a redundant API call. + """ + try: + logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if not cluster_response: + logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) + return None + qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response) + if qdrant_url: + logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) + return qdrant_url + logger.warning( + "No endpoints.public.host found for cluster %s, response keys: %s", + cluster_id, + list(cluster_response.keys()), + ) + except Exception: + logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) + return None + @staticmethod def create_tidb_serverless_cluster( project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str @@ -57,6 +100,7 @@ class TidbService: "rootPassword": password, } + logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region) response = _tidb_http_client.post( f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) ) @@ -64,21 +108,39 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_id = response_data["clusterId"] + logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id) retry_count = 0 max_retries = 30 while retry_count < max_retries: cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) if cluster_response["state"] == "ACTIVE": user_prefix = cluster_response["userPrefix"] + qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response) + logger.info( + "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", + cluster_id, + user_prefix, + qdrant_endpoint, + ) return { "cluster_id": cluster_id, "cluster_name": display_name, "account": f"{user_prefix}.root", "password": password, + "qdrant_endpoint": qdrant_endpoint, } - time.sleep(30) # wait 30 seconds before retrying + logger.info( + "Cluster %s state=%s, retry %d/%d", + cluster_id, + cluster_response["state"], + retry_count + 1, + max_retries, + ) + time.sleep(30) retry_count += 1 + logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) else: + logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() @staticmethod @@ -243,19 +305,29 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_infos = [] + logger.info("Batch created %d clusters", len(response_data.get("clusters", []))) for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" cached_password = redis_client.get(cache_key) if not cached_password: + logger.warning("No cached password for cluster %s, skipping", item["displayName"]) continue + qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, 
private_key, item["clusterId"]) + logger.info( + "Batch cluster %s: qdrant_endpoint=%s", + item["clusterId"], + qdrant_endpoint, + ) cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", "password": cached_password.decode("utf-8"), + "qdrant_endpoint": qdrant_endpoint, } cluster_infos.append(cluster_info) return cluster_infos else: + logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() return [] diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index 3e9229fea5..76802de62e 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -114,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds: assert exc_info.value.status_code == 500 - def test_delete_by_ids_with_large_batch(self, vector_instance): - """Test deletion with a large batch of IDs.""" - # Create 1000 IDs + def test_delete_by_ids_with_exactly_1000(self, vector_instance): + """Test deletion with exactly 1000 IDs triggers a single batch.""" ids = [f"doc_{i}" for i in range(1000)] vector_instance.delete_by_ids(ids) - # Verify single delete call with all IDs vector_instance._client.delete.assert_called_once() call_args = vector_instance._client.delete.call_args @@ -129,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds: filter_obj = filter_selector.filter field_condition = filter_obj.must[0] - # Verify all 1000 IDs are in the batch assert len(field_condition.match.any) == 1000 assert "doc_0" in field_condition.match.any assert "doc_999" in field_condition.match.any + def test_delete_by_ids_splits_into_batches(self, vector_instance): + """Test deletion with >1000 IDs triggers multiple batched calls.""" + ids = [f"doc_{i}" for i in range(2500)] + + vector_instance.delete_by_ids(ids) + + assert vector_instance._client.delete.call_count == 3 + + batches = [] + for call in vector_instance._client.delete.call_args_list: + filter_selector = call[1]["points_selector"] + field_condition = filter_selector.filter.must[0] + batches.append(field_condition.match.any) + + assert len(batches[0]) == 1000 + assert len(batches[1]) == 1000 + assert len(batches[2]) == 500 + def test_delete_by_ids_filter_structure(self, vector_instance): """Test that the filter structure is correctly constructed.""" ids = ["doc1", "doc2"] @@ -157,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds: # Verify MatchAny structure assert isinstance(field_condition.match, rest.MatchAny) assert field_condition.match.any == ids + + +class TestInitVectorEndpointSelection: + """Test that init_vector selects the correct qdrant endpoint. + + We avoid importing the full module (which triggers Flask app context) + by testing the endpoint selection logic directly on TidbOnQdrantConfig. 
+ """ + + def test_uses_binding_endpoint_when_present(self): + binding_endpoint = "https://qdrant-custom.tidb.com" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-custom.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-custom.tidb.com" + + def test_falls_back_to_global_when_binding_endpoint_is_none(self): + binding_endpoint = None + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-global.tidb.com" + + def test_falls_back_to_empty_when_both_none(self): + binding_endpoint = None + global_url = None + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "" + + def test_binding_endpoint_takes_precedence_over_global(self): + binding_endpoint = "https://qdrant-ap-southeast.tidb.com" + global_url = "https://qdrant-us-east.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-ap-southeast.tidb.com" + + def test_empty_string_binding_endpoint_falls_back_to_global(self): + binding_endpoint = "" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py new file mode 100644 index 0000000000..c1ffbacbbc --- /dev/null +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -0,0 +1,218 @@ +from unittest.mock import MagicMock, patch + +import pytest +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService + + +class TestExtractQdrantEndpoint: + """Unit tests for TidbService.extract_qdrant_endpoint.""" + + def test_returns_endpoint_when_host_present(self): + response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}} + result = TidbService.extract_qdrant_endpoint(response) + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + def test_returns_none_when_host_missing(self): + response = {"endpoints": {"public": {}}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_public_missing(self): + response = {"endpoints": {}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_endpoints_missing(self): + assert TidbService.extract_qdrant_endpoint({}) is None + + +class TestFetchQdrantEndpoint: + """Unit tests for TidbService.fetch_qdrant_endpoint.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_endpoint_when_host_present(self, mock_get_cluster): + mock_get_cluster.return_value = { + "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}} + } + result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster): + mock_get_cluster.return_value = None + assert TidbService.fetch_qdrant_endpoint("url", "pub", 
"priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_host_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {"endpoints": {"public": {}}} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_endpoints_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_on_exception(self, mock_get_cluster): + mock_get_cluster.side_effect = RuntimeError("network error") + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + +class TestCreateTidbServerlessClusterQdrantEndpoint: + """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = { + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}}, + } + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] is None + + +class TestBatchCreateTidbServerlessClusterQdrantEndpoint: + """Verify that batch_create includes qdrant_endpoint per cluster.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + cluster_name = "abc123" + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = b"password123" + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 1 + assert result[0]["qdrant_endpoint"] == 
"https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + 
mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 3b7e5f8e1f..2587d9e0bf 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -4,92 +4,50 @@ version = "1.13.3" requires-python = "~=3.12.0" dependencies = [ - "aliyun-log-python-sdk~=0.9.44", - "arize-phoenix-otel~=0.15.0", - "azure-identity==1.25.3", - "beautifulsoup4==4.14.3", - "boto3==1.42.88", - "bs4~=0.0.1", - "cachetools~=7.0.5", - "celery~=5.6.3", - "charset-normalizer>=3.4.7", - "flask~=3.1.3", - "flask-compress>=1.24,<1.25", - "flask-cors~=6.0.2", - "flask-login~=0.6.3", - "flask-migrate~=4.1.0", - "flask-orjson~=2.0.0", - "flask-sqlalchemy~=3.1.1", - "gevent~=26.4.0", - "gmpy2~=2.3.0", - "google-api-core>=2.30.3", - "google-api-python-client==2.194.0", - "google-auth>=2.49.2", - "google-auth-httplib2==0.3.1", - "google-cloud-aiplatform>=1.147.0", - "googleapis-common-protos>=1.74.0", - "graphon>=0.1.2", - "gunicorn~=25.3.0", - "httpx[socks]~=0.28.1", - "jieba==0.42.1", - "json-repair>=0.59.2", - "langfuse>=4.2.0,<5.0.0", - "langsmith~=0.7.30", - "markdown~=3.10.2", - "mlflow-skinny>=3.11.1", - "numpy~=2.4.4", - "openpyxl~=3.1.5", - "opik~=1.11.2", - "litellm==1.83.0", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.41.0", - "opentelemetry-distro==0.62b0", - "opentelemetry-exporter-otlp==1.41.0", - "opentelemetry-exporter-otlp-proto-common==1.41.0", - "opentelemetry-exporter-otlp-proto-grpc==1.41.0", - "opentelemetry-exporter-otlp-proto-http==1.41.0", - "opentelemetry-instrumentation==0.62b0", - "opentelemetry-instrumentation-celery==0.62b0", - "opentelemetry-instrumentation-flask==0.62b0", - "opentelemetry-instrumentation-httpx==0.62b0", - "opentelemetry-instrumentation-redis==0.62b0", - "opentelemetry-instrumentation-sqlalchemy==0.62b0", - "opentelemetry-propagator-b3==1.41.0", - "opentelemetry-proto==1.41.0", - "opentelemetry-sdk==1.41.0", - "opentelemetry-semantic-conventions==0.62b0", - "opentelemetry-util-http==0.62b0", - "pandas[excel,output-formatting,performance]~=3.0.2", - "psycogreen~=1.0.2", - "psycopg2-binary~=2.9.11", - "pycryptodome==3.23.0", - "pydantic~=2.12.5", - "pydantic-settings~=2.13.1", - "pyjwt~=2.12.1", - "pypdfium2==5.6.0", - "python-docx~=1.2.0", - "python-dotenv==1.2.2", - "pyyaml~=6.0.1", - "readabilipy~=0.3.0", - "redis[hiredis]~=7.4.0", - "resend~=2.27.0", - "sentry-sdk[flask]~=2.57.0", - "sqlalchemy~=2.0.49", - "starlette==1.0.0", - "tiktoken~=0.12.0", - "transformers~=5.3.0", - "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", - "pypandoc~=1.13", - "yarl~=1.23.0", - "sseclient-py~=1.9.0", - "httpx-sse~=0.4.0", - "sendgrid~=6.12.5", - "flask-restx~=1.3.2", - "packaging~=26.0", + # Legacy: mature and widely deployed + "bleach>=6.3.0", + "boto3>=1.42.96", + "celery>=5.6.3", "croniter>=6.2.2", - "apscheduler>=3.11.2", - "weave>=0.52.36", - "fastopenapi[flask]>=0.7.0", - "bleach~=6.3.0", + "flask>=3.1.3,<4.0.0", + "flask-cors>=6.0.2", + "gevent>=26.4.0", + "gevent-websocket>=0.10.1", + "gmpy2>=2.3.0", + "google-api-python-client>=2.194.0", + "gunicorn>=25.3.0", + "psycogreen>=1.0.2", + 
"psycopg2-binary>=2.9.12", + "python-socketio>=5.13.0", + "redis[hiredis]>=7.4.0", + "sendgrid>=6.12.5", + "sseclient-py>=1.8.0", + + # Stable: production-proven, cap below the next major + "aliyun-log-python-sdk>=0.9.44,<1.0.0", + "azure-identity>=1.25.3,<2.0.0", + "flask-compress>=1.24,<2.0.0", + "flask-login>=0.6.3,<1.0.0", + "flask-migrate>=4.1.0,<5.0.0", + "flask-orjson>=2.0.0,<3.0.0", + "flask-restx>=1.3.2,<2.0.0", + "google-cloud-aiplatform>=1.148.1,<2.0.0", + "httpx[socks]>=0.28.1,<1.0.0", + "opentelemetry-distro>=0.62b1,<1.0.0", + "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-httpx>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-redis>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-sqlalchemy>=0.62b0,<1.0.0", + "opentelemetry-propagator-b3>=1.41.1,<2.0.0", + "readabilipy>=0.3.0,<1.0.0", + "resend>=2.27.0,<3.0.0", + + # Emerging: newer and fast-moving, use compatible pins + "fastopenapi[flask]~=0.7.0", + "graphon~=0.2.2", + "httpx-sse~=0.4.0", + "json-repair~=0.59.4", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -98,8 +56,8 @@ dependencies = [ packages = [] [tool.uv.workspace] -members = ["providers/vdb/*"] -exclude = ["providers/vdb/__pycache__"] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] [tool.uv.sources] dify-vdb-alibabacloud-mysql = { workspace = true } @@ -132,10 +90,21 @@ dify-vdb-upstash = { workspace = true } dify-vdb-vastbase = { workspace = true } dify-vdb-vikingdb = { workspace = true } dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } [tool.uv] -default-groups = ["storage", "tools", "vdb-all"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false +override-dependencies = [ + "pyarrow>=18.0.0", +] [dependency-groups] @@ -144,70 +113,69 @@ package = false # Required for development and running tests ############################################################ dev = [ - "coverage~=7.13.4", - "dotenv-linter~=0.7.0", - "faker~=40.13.0", - "lxml-stubs~=0.5.1", - "basedpyright~=1.39.0", - "ruff~=0.15.10", - "pytest~=9.0.3", - "pytest-benchmark~=5.2.3", - "pytest-cov~=7.1.0", - "pytest-env~=1.6.0", - "pytest-mock~=3.15.1", - "testcontainers~=4.14.2", - "types-aiofiles~=25.1.0", - "types-beautifulsoup4~=4.12.0", - "types-cachetools~=6.2.0", - "types-colorama~=0.4.15", - "types-defusedxml~=0.7.0", - "types-deprecated~=1.3.1", - "types-docutils~=0.22.3", - "types-flask-cors~=6.0.0", - "types-flask-migrate~=4.1.0", - "types-gevent~=26.4.0", - "types-greenlet~=3.4.0", - "types-html5lib~=1.1.11", - "types-markdown~=3.10.2", - "types-oauthlib~=3.3.0", - "types-objgraph~=3.6.0", - "types-olefile~=0.47.0", - "types-openpyxl~=3.1.5", - "types-pexpect~=4.9.0", - "types-protobuf~=7.34.1", - "types-psutil~=7.2.2", - "types-psycopg2~=2.9.21", - "types-pygments~=2.20.0", - "types-pymysql~=1.1.0", - "types-python-dateutil~=2.9.0", - "types-pywin32~=311.0.0", - "types-pyyaml~=6.0.12", - "types-regex~=2026.4.4", - "types-shapely~=2.1.0", + "coverage>=7.13.4", + "dotenv-linter>=0.7.0", + "faker>=40.15.0", + 
"lxml-stubs>=0.5.1", + "basedpyright>=1.39.3", + "ruff>=0.15.12", + "pytest>=9.0.3", + "pytest-benchmark>=5.2.3", + "pytest-cov>=7.1.0", + "pytest-env>=1.6.0", + "pytest-mock>=3.15.1", + "testcontainers>=4.14.2", + "types-aiofiles>=25.1.0", + "types-beautifulsoup4>=4.12.0", + "types-cachetools>=6.2.0", + "types-colorama>=0.4.15", + "types-defusedxml>=0.7.0", + "types-deprecated>=1.3.1", + "types-docutils>=0.22.3", + "types-flask-cors>=6.0.0", + "types-flask-migrate>=4.1.0", + "types-gevent>=26.4.0", + "types-greenlet>=3.4.0", + "types-html5lib>=1.1.11", + "types-markdown>=3.10.2", + "types-oauthlib>=3.3.0", + "types-objgraph>=3.6.0", + "types-olefile>=0.47.0", + "types-openpyxl>=3.1.5", + "types-pexpect>=4.9.0", + "types-protobuf>=7.34.1", + "types-psutil>=7.2.2", + "types-psycopg2>=2.9.21.20260422", + "types-pygments>=2.20.0", + "types-pymysql>=1.1.0", + "types-python-dateutil>=2.9.0", + "types-pywin32>=311.0.0", + "types-pyyaml>=6.0.12", + "types-regex>=2026.4.4", + "types-shapely>=2.1.0", "types-simplejson>=3.20.0.20260408", "types-six>=1.17.0.20260408", "types-tensorflow>=2.18.0.20260408", "types-tqdm>=4.67.3.20260408", "types-ujson>=5.10.0", - "boto3-stubs>=1.42.88", + "boto3-stubs>=1.42.96", "types-jmespath>=1.1.0.20260408", - "hypothesis>=6.151.12", + "hypothesis>=6.152.3", "types_pyOpenSSL>=24.1.0", "types_cffi>=2.0.0.20260408", "types_setuptools>=82.0.0.20260408", - "pandas-stubs~=3.0.0", - "scipy-stubs>=1.15.3.0", + "pandas-stubs>=3.0.0", + "scipy-stubs>=1.17.1.4", "types-python-http-client>=3.3.7.20260408", "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.20.1", + "mypy>=1.20.2", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. - "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.60.0", - "xinference-client~=2.4.0", + "pyrefly>=0.62.0", + "xinference-client>=2.7.0", ] ############################################################ @@ -215,21 +183,21 @@ dev = [ # Required for storage clients ############################################################ storage = [ - "azure-storage-blob==12.28.0", - "bce-python-sdk~=0.9.69", - "cos-python-sdk-v5==1.9.41", - "esdk-obs-python==3.26.2", + "azure-storage-blob>=12.28.0", + "bce-python-sdk>=0.9.70", + "cos-python-sdk-v5>=1.9.42", + "esdk-obs-python>=3.22.2", "google-cloud-storage>=3.10.1", - "opendal~=0.46.0", - "oss2==2.19.1", - "supabase~=2.18.1", - "tos~=2.9.0", + "opendal>=0.46.0", + "oss2>=2.19.1", + "supabase>=2.29.0", + "tos>=2.9.0", ] ############################################################ # [ Tools ] dependency group ############################################################ -tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] +tools = ["cloudscraper>=1.2.71", "nltk>=3.9.1"] ############################################################ # [ VDB ] workspace plugins — hollow packages under providers/vdb/* @@ -299,7 +267,26 @@ vdb-vastbase = ["dify-vdb-vastbase"] vdb-vikingdb = ["dify-vdb-vikingdb"] vdb-weaviate = ["dify-vdb-weaviate"] # Optional client used by some tests / integrations (not a vector backend plugin) -vdb-xinference = ["xinference-client~=2.4.0"] +vdb-xinference = ["xinference-client>=2.7.0"] + +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] 
+trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] [tool.pyrefly] project-includes = ["."] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 3e5ece1fcf..fbbca24558 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,12 +34,12 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index c4582e891d..ac0e2a3a53 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -5,7 +5,8 @@ ".venv", "migrations/", "core/rag", - "providers/", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 100589804c..72b38e7906 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol, TypedDict -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index d5c6a203b1..44735eb769 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( 
DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b760696c5e..71a2554a60 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,21 +28,21 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm +from models.human_input import HumanInputForm, HumanInputFormRecipient from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -63,6 +63,7 @@ class _WorkflowRunError(Exception): def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, + recipients: Sequence[HumanInputFormRecipient] = (), ) -> HumanInputRequired: form_content = "" inputs = [] @@ -89,7 +90,7 @@ def _build_human_input_required_reason( resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - return HumanInputRequired( + reason = HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, @@ -98,6 +99,7 @@ def _build_human_input_required_reason( node_title=node_title, resolved_default_values=resolved_default_values, ) + return reason class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): @@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {} + if form_ids: + recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(recipient_stmt).all(): + recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient) pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - pause_reasons.append(_build_human_input_required_reason(reason, form_model)) + pause_reasons.append( + _build_human_input_required_reason( + reason, + form_model, + 
recipients_by_form_id.get(reason.form_id, ()), + ) + ) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index feba5f7eb6..67f8795d3f 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,9 +7,6 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -21,6 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/repositories/workflow_collaboration_repository.py b/api/repositories/workflow_collaboration_repository.py new file mode 100644 index 0000000000..000f80496d --- /dev/null +++ b/api/repositories/workflow_collaboration_repository.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +from typing import TypedDict + +from extensions.ext_redis import redis_client + +SESSION_STATE_TTL_SECONDS = 3600 +WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:" +WORKFLOW_LEADER_PREFIX = "workflow_leader:" +WS_SID_MAP_PREFIX = "ws_sid_map:" + + +class WorkflowSessionInfo(TypedDict): + user_id: str + username: str + avatar: str | None + sid: str + connected_at: int + + +class SidMapping(TypedDict): + workflow_id: str + user_id: str + + +class WorkflowCollaborationRepository: + def __init__(self) -> None: + self._redis = redis_client + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(redis_client={self._redis})" + + @staticmethod + def workflow_key(workflow_id: str) -> str: + return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}" + + @staticmethod + def leader_key(workflow_id: str) -> str: + return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}" + + @staticmethod + def sid_key(sid: str) -> str: + return f"{WS_SID_MAP_PREFIX}{sid}" + + @staticmethod + def _decode(value: str | bytes | None) -> str | None: + if value is None: + return None + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + workflow_key = self.workflow_key(workflow_id) + sid_key = self.sid_key(sid) + if self._redis.exists(workflow_key): + self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS) + if self._redis.exists(sid_key): + self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS) + + def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None: + workflow_key = self.workflow_key(workflow_id) + self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info)) + self._redis.set( + self.sid_key(session_info["sid"]), + json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}), + ex=SESSION_STATE_TTL_SECONDS, + ) + 
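+        # Extend the TTL on both presence keys so an active session does not expire mid-edit.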
self.refresh_session_state(workflow_id, session_info["sid"]) + + def get_sid_mapping(self, sid: str) -> SidMapping | None: + raw = self._redis.get(self.sid_key(sid)) + if not raw: + return None + value = self._decode(raw) + if not value: + return None + try: + return json.loads(value) + except (TypeError, json.JSONDecodeError): + return None + + def delete_session(self, workflow_id: str, sid: str) -> None: + self._redis.hdel(self.workflow_key(workflow_id), sid) + self._redis.delete(self.sid_key(sid)) + + def session_exists(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.hexists(self.workflow_key(workflow_id), sid)) + + def sid_mapping_exists(self, sid: str) -> bool: + return bool(self._redis.exists(self.sid_key(sid))) + + def get_session_sids(self, workflow_id: str) -> list[str]: + raw_sids = self._redis.hkeys(self.workflow_key(workflow_id)) + decoded_sids: list[str] = [] + for sid in raw_sids: + decoded = self._decode(sid) + if decoded: + decoded_sids.append(decoded) + return decoded_sids + + def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + sessions_json = self._redis.hgetall(self.workflow_key(workflow_id)) + users: list[WorkflowSessionInfo] = [] + + for session_info_json in sessions_json.values(): + value = self._decode(session_info_json) + if not value: + continue + try: + session_info = json.loads(value) + except (TypeError, json.JSONDecodeError): + continue + + if not isinstance(session_info, dict): + continue + if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info: + continue + + users.append( + { + "user_id": str(session_info["user_id"]), + "username": str(session_info["username"]), + "avatar": session_info.get("avatar"), + "sid": str(session_info["sid"]), + "connected_at": int(session_info.get("connected_at") or 0), + } + ) + + return users + + def get_current_leader(self, workflow_id: str) -> str | None: + raw = self._redis.get(self.leader_key(workflow_id)) + return self._decode(raw) + + def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS)) + + def set_leader(self, workflow_id: str, sid: str) -> None: + self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) + + def delete_leader(self, workflow_id: str) -> None: + self._redis.delete(self.leader_key(workflow_id)) + + def expire_leader(self, workflow_id: str) -> None: + self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index c4c203c150..e242b0c667 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -57,6 +57,7 @@ def create_clusters(batch_size): cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), active=False, status=TidbAuthBindingStatus.CREATING, ) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c..2cc0192a4a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import 
db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map: dict[str, list[str]] = {} + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(f"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/services/account_service.py b/api/services/account_service.py index ccc4a7c1fa..b6554a3de7 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: + # Phase-bound token metadata for the change-email flow. Tokens carry the
Tokens carry the + # current phase so that downstream endpoints can enforce proper progression + CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase" + CHANGE_EMAIL_PHASE_OLD = "old_email" + CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified" + CHANGE_EMAIL_PHASE_NEW = "new_email" + CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified" + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( @@ -576,13 +584,20 @@ class AccountService: raise ValueError("Email must be provided.") if not phase: raise ValueError("phase must be provided.") + if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW): + raise ValueError("phase must be one of old_email or new_email.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) - code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + code, token = cls.generate_change_email_token( + account_email, + account, + old_email=old_email, + additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase}, + ) send_change_mail_task.delay( language=language, diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ff0882ad5c..0229a1f43a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -133,7 +133,14 @@ class AppAnnotationService: raise ValueError("'question' is required when 'message_id' is not provided") question = maybe_question - annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) + annotation = MessageAnnotation( + app_id=app.id, + conversation_id=None, + message_id=None, + content=answer, + question=question, + account_id=current_user.id, + ) db.session.add(annotation) db.session.commit() diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 87c28980ab..97aaea3395 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -10,12 +10,6 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel @@ -23,6 +17,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -35,6 +30,12 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from 
graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType @@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): @@ -455,7 +456,7 @@ class AppDslService: app.updated_by = account.id self._session.add(app) - self._session.commit() + self._session.flush() app_was_created.send(app, account=account) # save dependencies diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 5e8c7aa337..d6c01e9dcc 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,6 +117,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) + quota_charge.commit() effective_mode = ( AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode ) @@ -162,6 +164,7 @@ class AppGenerateService: invoke_from=invoke_from, streaming=True, call_depth=0, + workflow_run_id=str(uuid.uuid4()), ) payload_json = payload.model_dump_json() @@ -183,6 +186,10 @@ class AppGenerateService: else: # Blocking mode: run synchronously and return JSON instead of SSE # Keep behaviour consistent with WORKFLOW blocking branch. 
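+                # Pause-state persistence for blocking runs: the config below wires an explicit session maker and records the workflow author (workflow.created_by) as the owner of any stored pause state.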
+ pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) advanced_generator = AdvancedChatAppGenerator() return rate_limit.generate( advanced_generator.convert_to_event_stream( @@ -194,6 +201,7 @@ class AppGenerateService: invoke_from=invoke_from, workflow_run_id=str(uuid.uuid4()), streaming=False, + pause_state_config=pause_config, ) ), request_id=request_id, diff --git a/api/services/app_service.py b/api/services/app_service.py index ef170c50ba..a046b909b3 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,8 +4,6 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from configs import dify_config @@ -17,6 +15,8 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -303,17 +303,22 @@ class AppService: return app - def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + def update_app_icon( + self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None + ) -> App: """ Update app icon :param app: App instance :param icon: new icon :param icon_background: new icon_background + :param icon_type: new icon type :return: App instance """ assert current_user is not None app.icon = icon app.icon_background = icon_background + if icon_type is not None: + app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type) app.updated_by = current_user.id app.updated_at = naive_utc_now() db.session.commit() diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 0842e9d3e7..6e9d6b1c73 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,11 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. 
""" -from graphon.graph_engine.manager import GraphEngineManager - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index a731d5c048..ceda30e950 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -88,7 +89,10 @@ class AsyncWorkflowService: raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") # 2. Get workflow - workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session) + + # commit read only session before starting the billig rpc call + session.commit() # 3. Get dispatcher based on tenant subscription dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) @@ -131,9 +135,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +158,18 @@ class AsyncWorkflowService: # 9. Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. 
Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED @@ -295,13 +305,21 @@ class AsyncWorkflowService: return [log.to_dict() for log in logs] @staticmethod - def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: + def _get_workflow( + workflow_service: WorkflowService, + app_model: App, + workflow_id: str | None = None, + session: Session | None = None, + ) -> Workflow: """ Get workflow for the app Args: app_model: App model instance workflow_id: Optional specific workflow ID + session: Reuse this SQLAlchemy session for the lookup when provided, + so the caller's explicit session bears the connection cost + instead of Flask's request-scoped ``db.session``. Returns: Workflow instance @@ -311,12 +329,12 @@ class AsyncWorkflowService: """ if workflow_id: # Get specific published workflow - workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session) if not workflow: raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") else: # Get default published workflow - workflow = workflow_service.get_published_workflow(app_model) + workflow = workflow_service.get_published_workflow(app_model, session=session) if not workflow: raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1c7027efb4..60948e652b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context -from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/billing_service.py b/api/services/billing_service.py index a1362ccad6..c0e23cdc6f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. 
+ """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -149,11 +193,63 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index ea12e40420..dcc93b4b0f 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,6 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker @@ -14,6 +13,7 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f5085af59b..ee8a1c4edd 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,7 +3,6 @@ import logging from collections.abc import Callable, Sequence from typing import Any -from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from 
sqlalchemy.orm import Session @@ -13,6 +12,7 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 95a8951951..287d513f48 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ -from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 6c6de192c6..eef38f1ce2 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,9 +10,6 @@ from collections.abc import Sequence from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa -from graphon.file import helpers as file_helpers -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -31,6 +28,9 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -3748,6 +3748,7 @@ class SegmentService: ChildChunk.segment_id == segment.id, ) ) + assert current_user.current_tenant_id child_chunk = ChildChunk( tenant_id=current_user.current_tenant_id, dataset_id=dataset.id, @@ -3758,7 +3759,7 @@ class SegmentService: index_node_hash=index_node_hash, content=content, word_count=len(content), - type="customized", + type=SegmentType.CUSTOMIZED, created_by=current_user.id, ) db.session.add(child_chunk) @@ -3818,6 +3819,7 @@ class SegmentService: if new_child_chunks_args: child_chunk_count = len(child_chunks) for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + assert current_user.current_tenant_id index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(args.content) child_chunk = ChildChunk( @@ -3830,7 +3832,7 @@ class SegmentService: index_node_hash=index_node_hash, content=args.content, word_count=len(args.content), - type="customized", + type=SegmentType.CUSTOMIZED, created_by=current_user.id, ) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 364c4a86a0..416bc8cef9 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,7 +3,6 @@ import time from collections.abc import Mapping from typing import Any -from 
graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -18,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3..bd7758f1c0 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime from typing import TYPE_CHECKING +from cachetools.func import ttl_cache from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config @@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None: class EnterpriseService: @classmethod + @ttl_cache(ttl=5) def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index f6a670415b..b1fe352861 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -87,7 +87,7 @@ class RetrievalModel(BaseModel): class MetaDataConfig(BaseModel): doc_type: str - doc_metadata: dict + doc_metadata: dict[str, Any] class KnowledgeConfig(BaseModel): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a944ef6acd..6679c08ebd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,15 +1,6 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -24,6 +15,15 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 6dcedfdced..60b457ecd0 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -4,13 +4,13 @@ from typing import Any, cast from urllib.parse import urlparse import httpx -from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy 
import func, select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities import MetadataFilteringCondition from extensions.ext_database import db +from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, diff --git a/api/services/feature_service.py b/api/services/feature_service.py index df653e0ba7..9477c28bf3 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -164,6 +166,7 @@ class SystemFeatureModel(BaseModel): enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False + enable_collaboration_mode: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False is_email_setup: bool = False @@ -174,6 +177,7 @@ class SystemFeatureModel(BaseModel): enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() trial_models: list[str] = [] + enable_creators_platform: bool = False enable_trial_app: bool = False enable_explore_banner: bool = False @@ -224,6 +228,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) @@ -237,6 +242,9 @@ class FeatureService: if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True + if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + system_features.enable_creators_platform = True + return system_features @classmethod @@ -244,6 +252,7 @@ class FeatureService: system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @@ -281,7 +290,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 79a935de4b..f60afe2f19 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,13 +2,12 @@ import base64 import hashlib import os import 
uuid -from collections.abc import Iterator, Sequence +from collections.abc import Generator, Sequence from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile -from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -24,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account @@ -324,7 +324,7 @@ class FileService: def build_upload_files_zip_tempfile( *, upload_files: Sequence[UploadFile], - ) -> Iterator[str]: + ) -> Generator[str, None, None]: """ Build a ZIP from `UploadFile`s and yield a tempfile path. diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 43985e49cd..2e5987dd28 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,17 +1,16 @@ import json import logging import time -from typing import Any, TypedDict - -from graphon.model_runtime.entities import LLMMode +from typing import Any, TypedDict, cast from core.app.app_config.entities import ModelConfig -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db +from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource @@ -37,6 +36,10 @@ default_retrieval_model = { } +class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): + metadata_filtering_conditions: dict[str, Any] + + class HitTestingService: @classmethod def retrieve( @@ -52,17 +55,18 @@ class HitTestingService: start = time.perf_counter() # get retrieval model; if the model is not set, use the default - if not retrieval_model: - retrieval_model = dataset.retrieval_model or default_retrieval_model - assert isinstance(retrieval_model, dict) + resolved_retrieval_model = cast( + HitTestingRetrievalModelDict, + retrieval_model or dataset.retrieval_model or default_retrieval_model, + ) document_ids_filter = None - metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions and query: + metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions_raw and query: + dataset_retrieval = DatasetRetrieval() from core.rag.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw) metadata_filter_document_ids, metadata_condition =
dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], @@ -79,19 +83,21 @@ class HitTestingService: if metadata_condition and not document_ids_filter: return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( - retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), + retrieval_method=RetrievalMethod( + resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) + ), dataset_id=dataset.id, query=query, attachment_ids=attachment_ids, - top_k=retrieval_model.get("top_k", 4), - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] + top_k=resolved_retrieval_model.get("top_k", 4), + score_threshold=resolved_retrieval_model.get("score_threshold", 0.0) + if resolved_retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] + reranking_model=resolved_retrieval_model.get("reranking_model", None) + if resolved_retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model", + weights=resolved_retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, ) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 77576fa4c0..8b4983e5f7 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,12 +4,11 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol -from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -18,6 +17,7 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 02a6620fc7..76598d31ac 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,12 +3,6 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -17,6 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now 
from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 5b133b0c04..8f5e028d4d 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,6 +1,7 @@ +import logging from collections.abc import Sequence +from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -14,10 +15,20 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating -from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback +from models.model import ( + App, + AppMode, + AppModelConfig, + AppModelConfigDict, + EndUser, + Message, + MessageFeedback, + SuggestedQuestionsAfterAnswerConfig, +) from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( SQLAlchemyExecutionExtraContentRepository, @@ -32,6 +43,7 @@ from services.errors.message import ( from services.workflow_service import WorkflowService _app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict) +logger = logging.getLogger(__name__) def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: @@ -252,6 +264,7 @@ class MessageService: ) model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) + suggested_questions_after_answer_config: SuggestedQuestionsAfterAnswerConfig = {"enabled": False} if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() @@ -271,9 +284,11 @@ class MessageService: if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, model_type=ModelType.LLM - ) + suggested_questions_after_answer = workflow.features_dict.get("suggested_questions_after_answer") + if isinstance(suggested_questions_after_answer, dict): + suggested_questions_after_answer_config = cast( + SuggestedQuestionsAfterAnswerConfig, suggested_questions_after_answer + ) else: if not conversation.override_model_configs: app_model_config = db.session.scalar( @@ -293,16 +308,14 @@ class MessageService: if not app_model_config: raise ValueError("did not find app model config") - suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict - if suggested_questions_after_answer.get("enabled", False) is False: + suggested_questions_after_answer_config = app_model_config.suggested_questions_after_answer_dict + if suggested_questions_after_answer_config.get("enabled", False) is False: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_model_instance( - tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict["provider"], - model_type=ModelType.LLM, - model=app_model_config.model_dict["name"], - ) + model_instance = model_manager.get_default_model_instance( + 
tenant_id=app_model.tenant_id, + model_type=ModelType.LLM, + ) # get memory of conversation (read-only) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) @@ -312,9 +325,17 @@ class MessageService: message_limit=3, ) + instruction_prompt = suggested_questions_after_answer_config.get("prompt") + if not isinstance(instruction_prompt, str) or not instruction_prompt.strip(): + instruction_prompt = None + + configured_model = suggested_questions_after_answer_config.get("model") with measure_time() as timer: questions_sequence = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, histories=histories + tenant_id=app_model.tenant_id, + histories=histories, + instruction_prompt=instruction_prompt, + model_config=configured_model, ) questions: list[str] = list(questions_sequence) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b652e049ce..c269346f5f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,12 +2,6 @@ import json import logging from typing import Any, TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -18,6 +12,12 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 3f37c9b176..51cda79661 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,10 +1,10 @@ import logging - -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule +from typing import Any from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -168,7 +168,9 @@ class ModelProviderService: model_name=model, ) - def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: str | None = None + ) -> dict[str, Any] | None: """ get provider credentials. 
@@ -180,7 +182,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) return provider_configuration.get_provider_credential(credential_id=credential_id) - def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]): """ validate provider credentials before saving. @@ -192,7 +194,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None + self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -210,7 +212,7 @@ class ModelProviderService: self, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: @@ -254,7 +256,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> dict | None: + ) -> dict[str, Any] | None: """ Retrieve model-specific credentials. @@ -270,7 +272,9 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict): + def validate_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any] + ): """ validate model credentials. @@ -287,7 +291,13 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict[str, Any], + credential_name: str | None, ) -> None: """ create and save model credentials. 
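A minimal call-site sketch of the tightened credential signatures in this file; it assumes `ModelProviderService()` constructs without arguments, and the tenant id, provider, model, and payload values are illustrative:

```python
from typing import Any

from services.model_provider_service import ModelProviderService

svc = ModelProviderService()

# Credential payloads are now typed as dict[str, Any] throughout the service.
creds: dict[str, Any] = {"api_key": "sk-..."}  # illustrative payload

# Provider-level credentials
svc.validate_provider_credentials(tenant_id="t-123", provider="openai", credentials=creds)

# Model-level credentials ("llm" is converted internally via ModelType.value_of)
svc.create_model_credential(
    tenant_id="t-123",
    provider="openai",
    model_type="llm",
    model="gpt-4o",
    credentials=creds,
    credential_name=None,
)
```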
@@ -314,7 +324,7 @@ class ModelProviderService: provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 9bb0ab6ae2..b96b8140ac 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -23,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -50,7 +50,7 @@ @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. + """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge).
+ + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. + + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: 
str, feature_key: str) -> None: + """Release a reservation. Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index aa7456dcd3..8c9a81af87 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod @@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 0ffbef8365..9d446f6d4b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database @@ -17,12 +41,10 @@ class 
CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 073eed221c..2964537c35 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ def get_pipeline_templates(self, language: str) -> dict[str, Any]: - result = self.fetch_pipeline_templates_from_db(language) - return result + return self.fetch_pipeline_templates_from_db(language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: 
PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index d5ef745bec..9565ac46cc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result: dict[str, Any] | None try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) def get_type(self) -> str: return PipelineTemplateType.REMOTE diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 5fc5b412b3..9db6682e10 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,15 +9,6 @@ from typing import Any, cast from uuid import uuid4 from flask_login import current_user -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -53,6 +44,15 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, 
build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore @@ -104,7 +104,7 @@ class RagPipelineService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @classmethod - def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: + def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -120,7 +120,7 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None: """ Get pipeline template detail. @@ -131,7 +131,7 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) if built_in_result is None: logger.warning( "pipeline template retrieval returned empty result, template_id: %s, mode: %s", @@ -142,7 +142,7 @@ class RagPipelineService: else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) return customized_result @classmethod @@ -297,7 +297,7 @@ class RagPipelineService: self, *, pipeline: Pipeline, - graph: dict, + graph: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -467,14 +467,16 @@ class RagPipelineService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: + def get_default_block_config( + self, node_type: str, filters: dict[str, Any] | None = None + ) -> Mapping[str, object] | None: """ Get default config of node. :param node_type: node type :param filters: filter by node config parameters. 
:return:
        """
-        node_type_enum = NodeType(node_type)
+        node_type_enum: NodeType = NodeType(node_type)
         node_mapping = get_node_type_classes_mapping()
 
         # return default block config
@@ -500,7 +502,7 @@ class RagPipelineService:
         return default_config
 
     def run_draft_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
+        self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account
     ) -> WorkflowNodeExecutionModel | None:
         """
         Run draft workflow node
@@ -582,7 +584,7 @@ class RagPipelineService:
         self,
         pipeline: Pipeline,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
         account: Account,
         datasource_type: str,
         is_published: bool,
@@ -749,7 +751,7 @@ class RagPipelineService:
         self,
         pipeline: Pipeline,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
         account: Account,
         datasource_type: str,
         is_published: bool,
@@ -979,7 +981,7 @@ class RagPipelineService:
         return workflow_node_execution
 
     def update_workflow(
-        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
+        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
     ) -> Workflow | None:
         """
         Update workflow attributes
@@ -1099,7 +1101,9 @@ class RagPipelineService:
         ]
         return datasource_provider_variables
 
-    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
+    def get_rag_pipeline_paginate_workflow_runs(
+        self, pipeline: Pipeline, args: dict[str, Any]
+    ) -> InfiniteScrollPagination:
         """
         Get debug workflow run list
         Only return triggered_from == debugging
@@ -1169,7 +1173,7 @@ class RagPipelineService:
         return list(node_executions)
 
     @classmethod
-    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
+    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
         """
         Publish customized pipeline template
         """
@@ -1259,7 +1263,7 @@ class RagPipelineService:
         )
         return node_exec
 
-    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
+    def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account):
         """
         Set datasource variables
         """
@@ -1346,7 +1350,7 @@ class RagPipelineService:
         )
         return workflow_node_execution_db_model
 
-    def get_recommended_plugins(self, type: str) -> dict:
+    def get_recommended_plugins(self, type: str) -> dict[str, Any]:
         # Query active recommended plugins
         stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
         if type and type != "all":
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 7dd86f1581..f315d053cb 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -13,12 +13,6 @@ import yaml  # type: ignore
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from flask_login import current_user
-from graphon.enums import BuiltinNodeTypes
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.nodes.llm.entities import LLMNodeData
-from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
-from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
-from graphon.nodes.tool.entities import ToolNodeData
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -33,6 +27,12 @@ from 
core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from extensions.ext_redis import redis_client
 from factories import variable_factory
+from graphon.enums import BuiltinNodeTypes
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.nodes.llm.entities import LLMNodeData
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
+from graphon.nodes.tool.entities import ToolNodeData
 from models import Account
 from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
 from models.enums import CollectionBindingType, DatasetRuntimeMode
diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
index ab60986bfe..21be411bea 100644
--- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
+++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
@@ -27,13 +27,13 @@ from dataclasses import dataclass, field
 from typing import Any, TypedDict
 
 import click
-from graphon.enums import WorkflowType
 from sqlalchemy import inspect
 from sqlalchemy.orm import Session, sessionmaker
 
 from configs import dify_config
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
+from graphon.enums import WorkflowType
 from libs.archive_storage import (
     ArchiveStorage,
     ArchiveStorageNotConfiguredError,
diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py
index c906e3bca3..cf39469be8 100644
--- a/api/services/summary_index_service.py
+++ b/api/services/summary_index_service.py
@@ -6,8 +6,6 @@ import uuid
 from datetime import UTC, datetime
 from typing import TypedDict, cast
 
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
@@ -18,6 +16,8 @@ from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from core.rag.models.document import Document
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs import helper
 from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
 from models.dataset import Document as DatasetDocument
@@ -349,7 +349,6 @@ class SummaryIndexService:
                 summary_record_id,
             )
             summary_record_in_session = DocumentSegmentSummary(
-                id=summary_record_id,  # Use the same ID if available
                 dataset_id=dataset.id,
                 document_id=segment.document_id,
                 chunk_id=segment.id,
@@ -360,6 +359,9 @@
                 status=SummaryStatus.COMPLETED,
                 enabled=True,
             )
+            if summary_record_id:
+                # Reuse the same record ID when one is available.
+                summary_record_in_session.id = summary_record_id
             session.add(summary_record_in_session)
             logger.info(
                 "Created new summary record (id=%s) for segment %s after vectorization",
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 3bfa221528..5ff2c21749 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ 
b/api/services/tools/api_tools_manage_service.py @@ -2,9 +2,9 @@ import json import logging from typing import Any, TypedDict, cast -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select +from sqlalchemy.orm import sessionmaker from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -16,11 +16,13 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ) +from core.tools.errors import ApiToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -116,71 +118,85 @@ class ApiToolManageService: privacy_policy: str, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - create api tool provider + Create a new API tool provider. + + :param user_id: The ID of the user creating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. 
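+
+        Example (illustrative only; the argument values below are hypothetical)::
+
+            ApiToolManageService.create_api_tool_provider(
+                user_id=account.id,
+                tenant_id=tenant.id,
+                provider_name="petstore",
+                icon={"content": "🐶", "background": "#FFFFFF"},
+                credentials={"auth_type": "none"},
+                schema_type="openapi",
+                schema=openapi_yaml_text,
+                privacy_policy="",
+                custom_disclaimer="",
+                labels=[],
+            )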
""" + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # Create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create db provider - db_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create API tool provider + api_tool_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - db.session.add(db_provider) - db.session.commit() + _session.add(api_tool_provider) - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) 
+ # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) return {"result": "success"} @@ -212,16 +228,25 @@ class ApiToolManageService: @staticmethod def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ - list api tool provider tools + List tools provided by a specific API tool provider. + + :param user_id: The ID of the user requesting the list. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :return: A list of ToolApiEntity objects. """ - provider: ApiToolProvider | None = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) if provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -251,103 +276,133 @@ class ApiToolManageService: privacy_policy: str | None, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - update api tool provider + Update an existing API tool provider. + + :param user_id: The ID of the user updating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The new name of the API tool provider. + :param original_provider: The original name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param _schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. 
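+
+        Note: credential values that arrive still masked (i.e. unchanged in the
+        client) are swapped back to their stored originals before re-encryption.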
""" + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + if provider is None: + raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) + original_credentials = 
encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - db.session.add(provider) - db.session.commit() + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = dict(encrypter.encrypt(credentials)) + provider.credentials_str = json.dumps(credentials) + + _session.add(provider) + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) # delete cache cache.delete() - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + Delete an API tool provider. + + :param user_id: The ID of the user performing the deletion operation. + :param tenant_id: The ID of the workspace/tenant where the provider belongs. + :param provider_name: The unique name of the API tool provider to be deleted. + :raises ValueError: If the specified provider does not exist in the tenant. + :return: A dictionary indicating the result status. """ - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider) - db.session.commit() + _session.delete(provider) return {"result": "success"} @staticmethod - def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]: """ - get api tool provider + Get API tool provider details. + + :param user_id: The ID of the user requesting the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider: The name of the API tool provider. + :return: A dictionary containing the provider details. """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) @@ -360,10 +415,20 @@ class ApiToolManageService: parameters: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, - ): + ) -> dict[str, Any]: """ - test api tool before adding api tool provider + Test an API tool before adding the API tool provider. + + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param tool_name: The name of the specific tool to test. + :param credentials: The credentials for the provider. + :param parameters: The parameters to pass to the tool. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :return: A dictionary containing the result or error message. 
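+
+        If the provider has not been added yet, a transient in-memory
+        ApiToolProvider is constructed so the tool can still be exercised.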
""" + if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -377,18 +442,21 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # create new session with automatic transaction management to get the provider + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if not db_provider: + if provider is None: # create a fake db provider - db_provider = ApiToolProvider( + provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -407,12 +475,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if db_provider.id: + if provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -443,14 +511,21 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - list api tools + List all API tools for a specific tenant. + + :param tenant_id: The ID of the workspace/tenant. + :return: A list of ToolProviderApiEntity objects. 
""" # get all api providers - db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + # create new session with automatic transaction management + providers: list[ApiToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + ) result: list[ToolProviderApiEntity] = [] - - for provider in db_providers: + for provider in providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 7bd056b8a0..b8242ab3a5 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID @@ -521,7 +521,7 @@ class BuiltinToolManageService: ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 779f7c4511..8f6600af03 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -3,7 +3,6 @@ import logging from datetime import datetime from typing import Any -from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, or_, select from sqlalchemy.orm import sessionmaker @@ -15,6 +14,7 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -139,62 +139,82 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - # check if the name is unique - existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id, - ) - .limit(1) - ) + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name exists for other tools + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == 
tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .limit(1) + ) + + # if the name exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) + # query the workflow tool provider + workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + # if not found raise error if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar( + select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) + ) + # if not found raise error if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider.name = name - workflow_tool_provider.label = label - workflow_tool_provider.icon = json.dumps(icon) - workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) - workflow_tool_provider.privacy_policy = privacy_policy - workflow_tool_provider.version = workflow.version - workflow_tool_provider.updated_at = datetime.now() + with sessionmaker(db.engine).begin() as _session: + _session.add(workflow_tool_provider) - try: - WorkflowToolProviderController.from_db(workflow_tool_provider) - except Exception as e: - raise ValueError(str(e)) + # update workflow tool provider + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() - db.session.commit() + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) - if labels is not None: - ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels - ) + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), + labels, + session=_session, + ) return {"result": "success"} diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 25e80770b8..a827222c1d 100644 --- 
a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,7 +2,6 @@ import json import logging from datetime import datetime -from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 6e14d996ea..b8a76e4945 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, @@ -635,7 +635,7 @@ class TriggerProviderService: if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 5a5d13b96d..911331e357 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,7 +5,6 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -21,6 +20,7 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index bb767a6759..5d99900a04 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,9 +7,6 @@ from typing import Any, NotRequired, TypedDict import orjson from flask import request -from graphon.entities.graph_config import NodeConfigDict -from graphon.file import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -31,6 +28,9 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client 
from factories import file_factory
+from graphon.entities.graph_config import NodeConfigDict
+from graphon.file import FileTransferMethod
+from graphon.variables.types import ArrayValidation, SegmentType
 from models.enums import AppTriggerStatus, AppTriggerType
 from models.model import App
 from models.trigger import AppTrigger, WorkflowWebhookTrigger
@@ -38,6 +38,7 @@ from models.workflow import Workflow
 from services.async_workflow_service import AsyncWorkflowService
 from services.end_user_service import EndUserService
 from services.errors.app import QuotaExceededError
+from services.quota_service import QuotaService
 from services.trigger.app_trigger_service import AppTriggerService
 from services.workflow.entities import WebhookTriggerData
@@ -798,45 +799,47 @@ class WebhookService:
             Exception: If workflow execution fails
         """
         try:
-            with Session(db.engine) as session:
-                # Prepare inputs for the webhook node
-                # The webhook node expects webhook_data in the inputs
-                workflow_inputs = cls.build_workflow_inputs(webhook_data)
+            workflow_inputs = cls.build_workflow_inputs(webhook_data)
 
-                # Create trigger data
-                trigger_data = WebhookTriggerData(
-                    app_id=webhook_trigger.app_id,
-                    workflow_id=workflow.id,
-                    root_node_id=webhook_trigger.node_id,  # Start from the webhook node
-                    inputs=workflow_inputs,
-                    tenant_id=webhook_trigger.tenant_id,
+            trigger_data = WebhookTriggerData(
+                app_id=webhook_trigger.app_id,
+                workflow_id=workflow.id,
+                root_node_id=webhook_trigger.node_id,
+                inputs=workflow_inputs,
+                tenant_id=webhook_trigger.tenant_id,
+            )
+
+            end_user = EndUserService.get_or_create_end_user_by_type(
+                type=InvokeFrom.TRIGGER,
+                tenant_id=webhook_trigger.tenant_id,
+                app_id=webhook_trigger.app_id,
+                user_id=None,
+            )
+
+            try:
+                quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
+            except QuotaExceededError:
+                AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
+                logger.info(
+                    "Tenant %s rate limited, skipping webhook trigger %s",
+                    webhook_trigger.tenant_id,
+                    webhook_trigger.webhook_id,
                 )
+                raise
 
-            end_user = EndUserService.get_or_create_end_user_by_type(
-                type=InvokeFrom.TRIGGER,
-                tenant_id=webhook_trigger.tenant_id,
-                app_id=webhook_trigger.app_id,
-                user_id=None,
-            )
-
-            # consume quota before triggering workflow execution
-            try:
-                QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
-            except QuotaExceededError:
-                AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
-                logger.info(
-                    "Tenant %s rate limited, skipping webhook trigger %s",
-                    webhook_trigger.tenant_id,
-                    webhook_trigger.webhook_id,
+            try:
+                # NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`;
+                # trigger_workflow_async needs to handle multiple session commits internally
+                with Session(db.engine, expire_on_commit=False) as session:
+                    AsyncWorkflowService.trigger_workflow_async(
+                        session,
+                        end_user,
+                        trigger_data,
                     )
-                raise
-
-            # Trigger workflow execution asynchronously
-            AsyncWorkflowService.trigger_workflow_async(
-                session,
-                end_user,
-                trigger_data,
-            )
+                quota_charge.commit()
+            except Exception:
+                quota_charge.refund()
+                raise
         except Exception:
             logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index 4d58a9cf12..1529c2b98f 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from collections.abc 
import Mapping from typing import Any, overload +from configs import dify_config from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( @@ -21,8 +22,6 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments -from configs import dify_config - _MAX_DEPTH = 100 @@ -170,7 +169,7 @@ class VariableTruncator(BaseTruncator): return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 9827c8dfbc..7e689af35d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,6 +1,5 @@ import logging -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager @@ -13,9 +12,11 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.enums import SegmentType logger = logging.getLogger(__name__) @@ -178,7 +179,7 @@ class VectorService: index_node_hash=child_chunk.metadata["doc_hash"], content=child_chunk.page_content, word_count=len(child_chunk.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=dataset_document.created_by, ) db.session.add(child_segment) @@ -222,6 +223,7 @@ class VectorService: ) documents.append(new_child_document) for update_child_chunk in update_child_chunks: + assert update_child_chunk.index_node_id child_document = Document( page_content=update_child_chunk.content, metadata={ @@ -234,6 +236,7 @@ class VectorService: documents.append(child_document) delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: + assert delete_child_chunk.index_node_id delete_node_ids.append(delete_child_chunk.index_node_id) if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index @@ -246,6 +249,7 @@ class VectorService: @classmethod def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) + assert child_chunk.index_node_id vector.delete_by_ids([child_chunk.index_node_id]) @classmethod diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1582bcd46c..5dedb9e372 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,11 +1,6 @@ import json from typing import Any, TypedDict -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import 
BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from sqlalchemy import select from core.app.app_config.entities import ( @@ -24,6 +19,11 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b5ab176ad2..59e02ec9b9 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,10 +3,10 @@ import uuid from datetime import datetime from typing import Any, TypedDict -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py new file mode 100644 index 0000000000..cf2f509052 --- /dev/null +++ b/api/services/workflow_collaboration_service.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping + +from sqlalchemy import select + +from core.db.session_factory import session_factory +from models.account import Account +from models.model import App +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo + +logger = logging.getLogger(__name__) + + +class WorkflowCollaborationService: + def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None: + self._repository = repository + self._socketio = socketio + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(repository={self._repository})" + + def save_socket_identity(self, sid: str, user: Account) -> None: + """Persist the authenticated console user on the raw socket session.""" + self._socketio.save_session( + sid, + { + "user_id": user.id, + "username": user.name, + "avatar": user.avatar, + "tenant_id": user.current_tenant_id, + }, + ) + + def authorize_and_join_workflow_room(self, workflow_id: str, sid: str) -> tuple[str, bool] | None: + """ + Join a collaboration room only after validating the socket session and tenant-scoped app access. + + The Socket.IO payload still calls the room key `workflow_id`, but the identifier is the workflow app's + `App.id`. Returning `None` lets the controller reject the join before any Redis or room state is created. 
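+
+        Returns a `(user_id, is_leader)` tuple once the join succeeds.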
+ """ + session = self._socketio.get_session(sid) + user_id = session.get("user_id") + tenant_id = session.get("tenant_id") + if not user_id or not tenant_id: + return None + + if not self._can_access_workflow(workflow_id, str(tenant_id)): + logger.warning( + "Workflow collaboration join rejected: workflow_id=%s tenant_id=%s user_id=%s sid=%s", + workflow_id, + tenant_id, + user_id, + sid, + ) + return None + + session_info: WorkflowSessionInfo = { + "user_id": str(user_id), + "username": str(session.get("username", "Unknown")), + "avatar": session.get("avatar"), + "sid": sid, + "connected_at": int(time.time()), + } + + self._repository.set_session_info(workflow_id, session_info) + + leader_sid = self.get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid + + self._socketio.enter_room(sid, workflow_id) + self.broadcast_online_users(workflow_id) + + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + + return str(user_id), is_leader + + def _can_access_workflow(self, workflow_id: str, tenant_id: str) -> bool: + """Check room access without relying on Flask's app-context-bound scoped session.""" + with session_factory.create_session() as session: + app_id = session.scalar(select(App.id).where(App.id == workflow_id, App.tenant_id == tenant_id).limit(1)) + return app_id is not None + + def disconnect_session(self, sid: str) -> None: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return + + workflow_id = mapping["workflow_id"] + self._repository.delete_session(workflow_id, sid) + + self.handle_leader_disconnect(workflow_id, sid) + self.broadcast_online_users(workflow_id) + + def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + user_id = mapping["user_id"] + self.refresh_session_state(workflow_id, sid) + + event_type = data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + if event_type == "sync_request": + leader_sid = self._repository.get_current_leader(workflow_id) + target_sid: str | None + if leader_sid and self.is_session_active(workflow_id, leader_sid): + target_sid = leader_sid + else: + if leader_sid: + self._repository.delete_leader(workflow_id) + target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if target_sid: + self._repository.set_leader(workflow_id, target_sid) + self.broadcast_leader_change(workflow_id, target_sid) + + if not target_sid: + return {"msg": "no_active_leader"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=target_sid, + ) + + return {"msg": "sync_request_forwarded"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"}, 200 + + def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": 
"graph_update_broadcasted"}, 200 + + def get_or_set_leader(self, workflow_id: str, sid: str) -> str: + current_leader = self._repository.get_current_leader(workflow_id) + + if current_leader: + if self.is_session_active(workflow_id, current_leader): + return current_leader + self._repository.delete_session(workflow_id, current_leader) + self._repository.delete_leader(workflow_id) + + was_set = self._repository.set_leader_if_absent(workflow_id, sid) + + if was_set: + if current_leader: + self.broadcast_leader_change(workflow_id, sid) + return sid + + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader: + return current_leader + + return sid + + def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_graph_leader(workflow_id) + if new_leader_sid: + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + else: + self._repository.delete_leader(workflow_id) + + def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit leader status to session %s", sid) + + def get_current_leader(self, workflow_id: str) -> str | None: + return self._repository.get_current_leader(workflow_id) + + def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + """Remove inactive sessions from storage and return active sessions only.""" + sessions = self._repository.list_sessions(workflow_id) + if not sessions: + return [] + + active_sessions: list[WorkflowSessionInfo] = [] + stale_sids: list[str] = [] + for session in sessions: + sid = session["sid"] + if self.is_session_active(workflow_id, sid): + active_sessions.append(session) + else: + stale_sids.append(sid) + + for sid in stale_sids: + self._repository.delete_session(workflow_id, sid) + + return active_sessions + + def broadcast_online_users(self, workflow_id: str) -> None: + users = self._prune_inactive_sessions(workflow_id) + users.sort(key=lambda x: x.get("connected_at") or 0) + + leader_sid = self.get_current_leader(workflow_id) + previous_leader = leader_sid + active_sids = {user["sid"] for user in users} + if leader_sid and leader_sid not in active_sids: + self._repository.delete_leader(workflow_id) + leader_sid = None + + if not leader_sid and users: + leader_sid = self._select_graph_leader(workflow_id) + if leader_sid: + self._repository.set_leader(workflow_id, leader_sid) + + if leader_sid != previous_leader: + self.broadcast_leader_change(workflow_id, leader_sid) + + self._socketio.emit( + "online_users", + {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, + room=workflow_id, + ) + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + self._repository.refresh_session_state(workflow_id, sid) + self._ensure_leader(workflow_id, sid) + + def _ensure_leader(self, workflow_id: str, sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader and self.is_session_active(workflow_id, current_leader): + self._repository.expire_leader(workflow_id) + return + + if 
current_leader: + self._repository.delete_leader(workflow_id) + + self._repository.set_leader(workflow_id, sid) + self.broadcast_leader_change(workflow_id, sid) + + def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + session["sid"] + for session in self._repository.list_sessions(workflow_id) + if session.get("graph_active", True) and self.is_session_active(workflow_id, session["sid"]) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def is_session_active(self, workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not self._socketio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not self._repository.session_exists(workflow_id, sid): + return False + + if not self._repository.sid_mapping_exists(sid): + return False + + return True diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py new file mode 100644 index 0000000000..ff47e4f253 --- /dev/null +++ b/api/services/workflow_comment_service.py @@ -0,0 +1,564 @@ +import logging +from collections.abc import Sequence + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, selectinload +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from libs.helper import uuid_value +from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply +from models.account import Account +from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task + +logger = logging.getLogger(__name__) + + +class WorkflowCommentService: + """Service for managing workflow comments.""" + + @staticmethod + def _validate_content(content: str) -> None: + if len(content.strip()) == 0: + raise ValueError("Comment content cannot be empty") + + if len(content) > 1000: + raise ValueError("Comment content cannot exceed 1000 characters") + + @staticmethod + def _filter_valid_mentioned_user_ids( + mentioned_user_ids: Sequence[str], *, session: Session, tenant_id: str + ) -> list[str]: + """Return deduplicated UUID user IDs that belong to the tenant, preserving input order.""" + unique_user_ids: list[str] = [] + seen: set[str] = set() + for user_id in mentioned_user_ids: + if not isinstance(user_id, str): + continue + if not uuid_value(user_id): + continue + if user_id in seen: + continue + seen.add(user_id) + unique_user_ids.append(user_id) + if not unique_user_ids: + return [] + + tenant_member_ids = { + str(account_id) + for account_id in session.scalars( + select(TenantAccountJoin.account_id).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id.in_(unique_user_ids), + ) + ).all() + } + + return [user_id for user_id in unique_user_ids if user_id in tenant_member_ids] + + @staticmethod + def _format_comment_excerpt(content: str, max_length: int = 200) -> str: + """Trim comment content for email display.""" + trimmed = content.strip() + if len(trimmed) <= max_length: + return trimmed + if max_length <= 3: + return trimmed[:max_length] + return f"{trimmed[: max_length - 3].rstrip()}..." 
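+
+    # Behavior sketch for the truncation rule above (assuming the default
+    # max_length of 200): three characters are reserved for the ellipsis, and
+    # trailing whitespace is stripped before it is appended. For example:
+    #
+    #     _format_comment_excerpt("short comment")        -> "short comment"
+    #     _format_comment_excerpt("a" * 250)              -> "a" * 197 + "..."
+    #     _format_comment_excerpt("abcdef", max_length=3) -> "abc"  (no ellipsis fits)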
+ + @staticmethod + def _build_mention_email_payloads( + session: Session, + tenant_id: str, + app_id: str, + mentioner_id: str, + mentioned_user_ids: Sequence[str], + content: str, + ) -> list[dict[str, str]]: + """Prepare email payloads for mentioned users, including workflow app link.""" + if not mentioned_user_ids: + return [] + + candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id] + if not candidate_user_ids: + return [] + + app_name_value = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) + app_name = app_name_value if isinstance(app_name_value, str) and app_name_value else "Dify app" + commenter_name_value = session.scalar(select(Account.name).where(Account.id == mentioner_id)) + commenter_name = ( + commenter_name_value if isinstance(commenter_name_value, str) and commenter_name_value else "Dify user" + ) + comment_excerpt = WorkflowCommentService._format_comment_excerpt(content) + base_url = dify_config.CONSOLE_WEB_URL.rstrip("/") + app_url = f"{base_url}/app/{app_id}/workflow" + + accounts = session.scalars( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids)) + ).all() + + payloads: list[dict[str, str]] = [] + for account in accounts: + email = account.email + if not isinstance(email, str) or not email: + continue + mentioned_name = account.name if isinstance(account.name, str) and account.name else email + language = ( + account.interface_language + if isinstance(account.interface_language, str) and account.interface_language + else "en-US" + ) + payloads.append( + { + "language": language, + "to": email, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_excerpt, + "app_url": app_url, + } + ) + return payloads + + @staticmethod + def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None: + """Enqueue mention notification emails.""" + for payload in payloads: + send_workflow_comment_mention_email_task.delay(**payload) + + @staticmethod + def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]: + """Get all comments for a workflow.""" + with Session(db.engine) as session: + # Get all comments with eager loading + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id) + .order_by(desc(WorkflowComment.created_at)) + ) + + comments = session.scalars(stmt).all() + + # Batch preload all Account objects to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, comments) + + return comments + + @staticmethod + def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None: + """Batch preload Account objects for comments, replies, and mentions.""" + # Collect all user IDs + user_ids: set[str] = set() + for comment in comments: + user_ids.add(comment.created_by) + if comment.resolved_by: + user_ids.add(comment.resolved_by) + user_ids.update(reply.created_by for reply in comment.replies) + user_ids.update(mention.mentioned_user_id for mention in comment.mentions) + + if not user_ids: + return + + # Batch query all accounts + accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all() + account_map = {str(account.id): account for account in accounts} + + # Cache accounts on 
objects + for comment in comments: + comment.cache_created_by_account(account_map.get(comment.created_by)) + comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None) + for reply in comment.replies: + reply.cache_created_by_account(account_map.get(reply.created_by)) + for mention in comment.mentions: + mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id)) + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment: + """Get a specific comment.""" + + def _get_comment(session: Session) -> WorkflowComment: + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Preload accounts to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, [comment]) + + return comment + + if session is not None: + return _get_comment(session) + else: + with Session(db.engine, expire_on_commit=False) as session: + return _get_comment(session) + + @staticmethod + def create_comment( + tenant_id: str, + app_id: str, + created_by: str, + content: str, + position_x: float, + position_y: float, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Create a new workflow comment and send mention notification emails.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine) as session: + comment = WorkflowComment( + tenant_id=tenant_id, + app_id=app_id, + position_x=position_x, + position_y=position_y, + content=content, + created_by=created_by, + ) + + session.add(comment) + session.flush() # Get the comment ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=tenant_id, + ) + for user_id in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention, not reply mention + mentioned_user_id=user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + # Return only what we need - id and created_at + return {"id": comment.id, "created_at": comment.created_at} + + @staticmethod + def update_comment( + tenant_id: str, + app_id: str, + comment_id: str, + user_id: str, + content: str, + position_x: float | None = None, + position_y: float | None = None, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a workflow comment and notify newly mentioned users. + + `mentioned_user_ids=None` means "leave mentions unchanged". + Passing an explicit list replaces the existing comment mentions, including clearing them with `[]`. 
+ """ + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Get comment with validation + stmt = select(WorkflowComment).where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Only the creator can update the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can update it") + + # Update comment fields + comment.content = content + if position_x is not None: + comment.position_x = position_x + if position_y is not None: + comment.position_y = position_y + + mention_email_payloads: list[dict[str, str]] = [] + if mentioned_user_ids is not None: + # Replace comment mentions only when the client explicitly sends the mention list. + existing_mentions = session.scalars( + select(WorkflowCommentMention).where( + WorkflowCommentMention.comment_id == comment.id, + WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions + ) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + filtered_mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids, + session=session, + tenant_id=tenant_id, + ) + new_mentioned_user_ids = [ + mentioned_user_id + for mentioned_user_id in filtered_mentioned_user_ids + if mentioned_user_id not in existing_mentioned_user_ids + ] + for mentioned_user_id in filtered_mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention + mentioned_user_id=mentioned_user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": comment.id, "updated_at": comment.updated_at} + + @staticmethod + def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None: + """Delete a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + + # Only the creator can delete the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can delete it") + + # Delete associated mentions (both comment and reply mentions) + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id) + ).all() + for mention in mentions: + session.delete(mention) + + # Delete associated replies + replies = session.scalars( + select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id) + ).all() + for reply in replies: + session.delete(reply) + + session.delete(comment) + session.commit() + + @staticmethod + def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment: + """Resolve a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + if comment.resolved: + return comment + + comment.resolved = True + 
comment.resolved_at = naive_utc_now() + comment.resolved_by = user_id + session.commit() + + return comment + + @staticmethod + def create_reply( + comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None + ) -> dict: + """Add a reply to a workflow comment and notify mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Check if comment exists + comment = session.get(WorkflowComment, comment_id) + if not comment: + raise NotFound("Comment not found") + + reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by) + + session.add(reply) + session.flush() # Get the reply ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=comment.tenant_id, + ) + for user_id in mentioned_user_ids: + # Create mention linking to specific reply + mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "created_at": reply.created_at} + + @staticmethod + def _get_reply_in_comment_scope( + *, + session: Session, + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + ) -> WorkflowCommentReply: + """Get a reply scoped to tenant/app/comment to prevent cross-thread mutations.""" + stmt = ( + select(WorkflowCommentReply) + .join(WorkflowComment, WorkflowComment.id == WorkflowCommentReply.comment_id) + .where( + WorkflowCommentReply.id == reply_id, + WorkflowCommentReply.comment_id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + .limit(1) + ) + reply = session.scalar(stmt) + if not reply: + raise NotFound("Reply not found") + return reply + + @staticmethod + def update_reply( + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + user_id: str, + content: str, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a comment reply and notify newly mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can update the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can update it") + + reply.content = content + + # Update mentions - first remove existing mentions for this reply + existing_mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + # Add mentions + raw_mentioned_user_ids = mentioned_user_ids or [] + comment = session.get(WorkflowComment, reply.comment_id) + mentioned_user_ids = [] + if comment: + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( 
+ raw_mentioned_user_ids, + session=session, + tenant_id=comment.tenant_id, + ) + new_mentioned_user_ids = [ + user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids + ] + for user_id_str in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str + ) + session.add(mention) + + mention_email_payloads: list[dict[str, str]] = [] + if comment: + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + session.refresh(reply) # Refresh to get updated timestamp + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "updated_at": reply.updated_at} + + @staticmethod + def delete_reply(tenant_id: str, app_id: str, comment_id: str, reply_id: str, user_id: str) -> None: + """Delete a comment reply.""" + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can delete the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can delete it") + + # Delete associated mentions first + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id) + ).all() + for mention in mentions: + session.delete(mention) + + session.delete(reply) + session.commit() + + @staticmethod + def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment: + """Validate that a comment belongs to the specified tenant and app.""" + return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 2cc6e21574..96f936ff9b 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -7,19 +7,6 @@ from datetime import datetime from enum import StrEnum from typing import Any, ClassVar, NotRequired, TypedDict -from graphon.enums import NodeType -from graphon.file import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -40,6 +27,19 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader 
+from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation @@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -1067,7 +1067,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 601e9261fc..94f88f8c49 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,15 +9,12 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker from core.app.apps.message_generator import MessageGenerator from core.app.entities.task_entities import ( + HumanInputRequiredResponse, MessageReplaceStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, @@ -26,6 +23,14 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from models.human_input import HumanInputForm from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot @@ -59,8 +64,10 @@ def build_workflow_event_stream( tenant_id: str, app_id: str, session_maker: sessionmaker[Session], + human_input_surface: 
HumanInputSurface | None = None, idle_timeout: float = 300, ping_interval: float = 10.0, + close_on_pause: bool = True, ) -> Generator[Mapping[str, Any] | str, None, None]: topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @@ -115,13 +122,15 @@ def build_workflow_event_stream( message_context=message_context, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) for event in snapshot_events: last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return while True: @@ -146,7 +155,7 @@ def build_workflow_event_stream( last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return finally: buffer_state.stop_event.set() @@ -207,6 +216,8 @@ def _build_snapshot_events( message_context: MessageContext | None, pause_entity: WorkflowPauseEntity | None, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None = None, + human_input_surface: HumanInputSurface | None = None, ) -> list[Mapping[str, Any]]: events: list[Mapping[str, Any]] = [] @@ -241,12 +252,24 @@ def _build_snapshot_events( events.append(node_finished) if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: + for human_input_event in _build_human_input_required_events( + workflow_run_id=workflow_run.id, + task_id=task_id, + pause_entity=pause_entity, + session_maker=session_maker, + human_input_surface=human_input_surface, + ): + _apply_message_context(human_input_event, message_context) + events.append(human_input_event) + pause_event = _build_pause_event( workflow_run=workflow_run, workflow_run_id=workflow_run.id, task_id=task_id, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) if pause_event is not None: _apply_message_context(pause_event, message_context) @@ -314,6 +337,97 @@ def _build_node_started_event( return response.to_ignore_detail_dict() +def _build_human_input_required_events( + *, + workflow_run_id: str, + task_id: str, + pause_entity: WorkflowPauseEntity, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None, +) -> list[dict[str, Any]]: + reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + + expiration_times_by_form_id: dict[str, int] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_tokens_by_form_id: dict[str, str] = {} + if human_input_form_ids and session_maker is not None: + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + with session_maker() as session: + for form_id, expiration_time, form_definition in session.execute(stmt): + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except 
(TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + + events: list[dict[str, Any]] = [] + for reason in reasons: + if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED: + continue + + form_id_raw = reason.get("form_id") + node_id_raw = reason.get("node_id") + node_title_raw = reason.get("node_title") + form_content_raw = reason.get("form_content") + if not isinstance(form_id_raw, str): + continue + if not isinstance(node_id_raw, str): + continue + if not isinstance(node_title_raw, str): + continue + if not isinstance(form_content_raw, str): + continue + form_id = form_id_raw + node_id = node_id_raw + node_title = node_title_raw + form_content = form_content_raw + + inputs = reason.get("inputs") + actions = reason.get("actions") + resolved_default_values = reason.get("resolved_default_values") + + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is None: + continue + + response = HumanInputRequiredResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=HumanInputRequiredResponse.Data( + form_id=form_id, + node_id=node_id, + node_title=node_title, + form_content=form_content, + inputs=inputs if isinstance(inputs, list) else [], + actions=actions if isinstance(actions, list) else [], + display_in_ui=display_in_ui_by_form_id.get(form_id, False), + form_token=form_tokens_by_form_id.get(form_id), + resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}), + expiration_time=expiration_time, + ), + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + events.append(payload) + + return events + + def _build_node_finished_event( *, workflow_run_id: str, @@ -356,6 +470,8 @@ def _build_pause_event( task_id: str, pause_entity: WorkflowPauseEntity, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None = None, ) -> dict[str, Any] | None: paused_nodes: list[str] = [] outputs: dict[str, Any] = {} @@ -365,6 +481,36 @@ def _build_pause_event( outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + form_tokens_by_form_id: dict[str, str] = {} + expiration_times_by_form_id: dict[str, int] = {} + if human_input_form_ids and session_maker is not None: + with session_maker() as session: + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + for row in session.execute(stmt): + form_id, expiration_time, *_rest = row + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + # Reconnect paths must preserve the same pause-reason contract as live streams; + # otherwise clients see schema drift after resume. 
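+    # Assumed enrichment semantics (keyed by form_id): each HUMAN_INPUT_REQUIRED
+    # reason picks up the matching "form_token" and "expiration_time" resolved
+    # above, while other reason types, and forms with no lookup hit, pass
+    # through unchanged.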
+ reasons = enrich_human_input_pause_reasons( + reasons, + form_tokens_by_form_id=form_tokens_by_form_id, + expiration_times_by_form_id=expiration_times_by_form_id, + ) + response = WorkflowPauseStreamResponse( task_id=task_id, workflow_run_id=workflow_run_id, @@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: return event -def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: +def _is_terminal_event( + event: Mapping[str, Any] | str, + close_on_pause: bool = True, + *, + include_paused: bool | None = None, +) -> bool: + if include_paused is not None: + close_on_pause = include_paused if not isinstance(event, Mapping): return False event_type = event.get("event") if event_type == StreamEvent.WORKFLOW_FINISHED.value: return True - if include_paused: + if close_on_pause: return event_type == StreamEvent.WORKFLOW_PAUSED.value return False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 839b9e3319..f97b85dc2b 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,40 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from sqlalchemy import exists, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.entities import PluginCredentialType +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager +from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_adapter import ( + DeliveryChannelConfig, + adapt_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) +from core.workflow.node_factory import ( + LATEST_VERSION, + DifyGraphInitContext, + get_node_type_classes_mapping, + is_start_node_type, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from enums.cloud_plan import CloudPlan +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from extensions.ext_storage import storage +from factories.file_factory import build_from_mapping, build_from_mappings from graphon.entities import WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired @@ -30,40 +64,6 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable -from sqlalchemy import 
exists, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.entities import PluginCredentialType -from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager -from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import is_trigger_node_type -from core.workflow.human_input_compat import ( - DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, - parse_human_input_delivery_methods, -) -from core.workflow.node_factory import ( - LATEST_VERSION, - DifyGraphInitContext, - get_node_type_classes_mapping, - is_start_node_type, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace -from enums.cloud_plan import CloudPlan -from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated -from extensions.ext_database import db -from extensions.ext_storage import storage -from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -156,11 +156,18 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: + def get_published_workflow_by_id( + self, app_model: App, workflow_id: str, session: Session | None = None + ) -> Workflow | None: """ fetch published workflow by workflow_id + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -178,16 +185,20 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Workflow | None: + def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None: """ Get published workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. 
""" if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -199,6 +210,16 @@ class WorkflowService: return workflow + def get_accessible_app_ids(self, app_ids: Sequence[str], tenant_id: str) -> set[str]: + """ + Return app IDs that belong to the given tenant. + """ + if not app_ids: + return set() + + stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + return {str(app_id) for app_id in db.session.scalars(stmt).all()} + def get_all_published_workflow( self, *, @@ -241,8 +262,8 @@ class WorkflowService: self, *, app_model: App, - graph: dict, - features: dict, + graph: dict[str, Any], + features: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -296,6 +317,78 @@ class WorkflowService: # return draft workflow return workflow + def update_draft_workflow_environment_variables( + self, + *, + app_model: App, + environment_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow environment variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.environment_variables = environment_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_conversation_variables( + self, + *, + app_model: App, + conversation_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow conversation variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.conversation_variables = conversation_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_features( + self, + *, + app_model: App, + features: dict, + account: Account, + ): + """ + Update draft workflow features + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + def restore_published_workflow_to_draft( self, *, @@ -576,7 +669,7 @@ class WorkflowService: except Exception as e: raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") - def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ Validate load balancing credentials for a workflow node. @@ -709,7 +802,7 @@ class WorkflowService: :param filters: filter by node config parameters. 
:return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -1014,7 +1107,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1155,9 +1248,10 @@ class WorkflowService: variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), @@ -1214,7 +1308,7 @@ class WorkflowService: return variable_pool def run_free_workflow_node( - self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ Run free workflow node @@ -1361,7 +1455,7 @@ class WorkflowService: node_execution.status = WorkflowNodeExecutionStatus.FAILED node_execution.error = error - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App: """ Basic mode of chatbot app(expert mode) to workflow Completion App to Workflow App @@ -1421,7 +1515,7 @@ class WorkflowService: if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): + def validate_features_structure(self, app_model: App, features: dict[str, Any]): match app_model.mode: case AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -1434,7 +1528,7 @@ class WorkflowService: case _: raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: + def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None: """ Validate HumanInput node data format. 
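
The rename these hunks carry, `normalize_human_input_node_data_for_graph` (from `core.workflow.human_input_compat`) becoming `adapt_human_input_node_data_for_graph` (from `core.workflow.human_input_adapter`), keeps the call shape intact at every site. A minimal usage sketch (the payload is hypothetical; the adapter is assumed to be a pure dict-to-dict transform, since each call site feeds its result straight into `HumanInputNodeData.model_validate`):

    from core.workflow.human_input_adapter import adapt_human_input_node_data_for_graph
    from graphon.nodes.human_input.entities import HumanInputNodeData

    # Hypothetical raw node payload as stored in a draft graph.
    raw_node_data = {"title": "Approval step", "delivery_methods": []}
    node_data = HumanInputNodeData.model_validate(
        adapt_human_input_node_data_for_graph(raw_node_data)
    )
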
@@ -1447,12 +1541,12 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 8f2f5f261e..5ceeb302c8 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,7 +7,6 @@ from typing import Annotated, Any from celery import shared_task from flask import current_app, json -from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -23,6 +22,7 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -399,6 +399,8 @@ def _resume_advanced_chat( workflow_run_id: str, workflow_run: WorkflowRun, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except ValueError: @@ -426,7 +428,7 @@ def _resume_advanced_chat( user=user, conversation=conversation, message=message, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, graph_runtime_state=graph_runtime_state, @@ -436,9 +438,8 @@ def _resume_advanced_chat( logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) raise - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) def _resume_workflow( @@ -455,6 +456,8 @@ def _resume_workflow( workflow_run_repo, pause_entity, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except ValueError: @@ -480,7 +483,7 @@ def _resume_workflow( app_model=app_model, workflow=workflow, user=user, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, graph_runtime_state=graph_runtime_state, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, @@ -490,11 +493,18 @@ def _resume_workflow( logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) raise - if 
generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) - workflow_run_repo.delete_workflow_pause(pause_entity) + try: + workflow_run_repo.delete_workflow_pause(pause_entity) + except Exception as exc: + if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc): + raise + logger.info( + "Skipped deleting workflow pause %s after resume because it was already replaced or removed", + pause_entity.id, + ) @shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 9ff34c7c48..5809268992 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,7 +10,6 @@ from datetime import UTC, datetime from typing import Any, NotRequired from celery import shared_task -from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from typing_extensions import TypedDict @@ -24,6 +23,7 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 4db551c73c..beb23d8354 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -8,7 +8,6 @@ from typing import Any import click import pandas as pd from celery import shared_task -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.db.session_factory import session_factory @@ -16,6 +15,7 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a657cd553a..c8d0e31c06 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i # check segment is exist if index_node_ids: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. 
billing API hiccup propagated via FeatureService when + # ModelManager is initialized inside ``Vector(dataset)``) does not abort + # the entire task and leave document_segments / child_chunks / image_files + # / metadata bindings stranded in PG. Mirrors the pattern already used in + # ``clean_dataset_task`` so the document row's hard delete (already + # committed by the caller) does not produce orphan PG rows just because + # the vector backend or one of its transitive dependencies was unhappy. + try: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_document_task, " + "document_id=%s, dataset_id=%s, index_node_ids_count=%d. " + "Continuing with PG / storage cleanup; vector orphans can be reaped later.", + document_id, + dataset_id, + len(index_node_ids), + ) total_image_files = [] with session_factory.create_session() as session, session.begin(): diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index e3be24ac74..017d60efac 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() total_index_node_ids.extend([segment.index_node_id for segment in segments]) - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. billing API hiccup propagated via FeatureService when + # ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort + # the task and leave the already-deleted documents' segments stranded in PG. + # The Document rows are hard-deleted in the previous session block, so any + # exception escaping this task would produce orphans that no later request + # can reference back. Mirrors the pattern in ``clean_dataset_task``. + try: + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_notion_document_task, " + "dataset_id=%s, document_ids=%s, index_node_ids_count=%d. 
" + "Continuing with segment deletion; vector orphans can be reaped later.", + dataset_id, + document_ids, + len(total_index_node_ids), + ) with session_factory.create_session() as session, session.begin(): segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index ca73b4d374..fd743205a1 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,8 +2,6 @@ import logging from datetime import timedelta from celery import shared_task -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -11,6 +9,8 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index a316eec7b9..2a60be7762 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,15 +6,15 @@ from typing import Any import click from celery import shared_task -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/mail_workflow_comment_task.py b/api/tasks/mail_workflow_comment_task.py new file mode 100644 index 0000000000..36d51f0514 --- /dev/null +++ b/api/tasks/mail_workflow_comment_task.py @@ -0,0 +1,65 @@ +import logging +import time + +import click +from celery import shared_task + +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_workflow_comment_mention_email_task( + language: str, + to: str, + mentioned_name: str, + commenter_name: str, + app_name: str, + comment_content: str, + app_url: str, +): + """ + Send workflow comment mention email with internationalization support. 
+ + Args: + language: Language code for email localization + to: Recipient email address + mentioned_name: Name of the mentioned user + commenter_name: Name of the comment author + app_name: Name of the app where the comment was made + comment_content: Comment content excerpt + app_url: Link to the app workflow page + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.WORKFLOW_COMMENT_MENTION, + language_code=language, + to=to, + template_context={ + "to": to, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_content, + "app_url": app_url, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("workflow comment mention email to %s failed", to) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 5d201bd801..48d1774ce3 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy +from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task( fg="green", ) ) - _ = manager.upgrade_plugin( + # Use the service that downloads and uploads the package to the daemon + # first; calling manager.upgrade_plugin directly skips that step and the + # daemon fails because the package never reaches its local bucket. 
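+ # (upgrade_plugin_with_marketplace also fills in the Marketplace install source and
+ # installation metadata that were previously passed here explicitly, as the removed
+ # arguments below show.)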
+ _ = PluginService.upgrade_plugin_with_marketplace( tenant_id, original_unique_identifier, new_unique_identifier, - PluginInstallationSource.Marketplace, - { - "plugin_unique_identifier": new_unique_identifier, - }, ) except Exception as e: click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 56626e372e..8505375b6a 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py
@@ -12,7 +12,6 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session
@@ -28,7 +27,8 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole,
@@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@@ -258,59 +259,58 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity + + # Ensure expire_on_commit is set to False so the loaded workflows remain usable after the session closes with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) - end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( - type=InvokeFrom.TRIGGER, - tenant_id=subscription.tenant_id, - app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], - user_id=user_id, - ) - for plugin_trigger in subscribers: - # Get workflow from mapping - workflow: Workflow | None = workflows.get(plugin_trigger.app_id) - if not workflow: - logger.error( - "Workflow not found for app %s", - plugin_trigger.app_id, - ) - continue + end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( + type=InvokeFrom.TRIGGER, + tenant_id=subscription.tenant_id, + app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], + user_id=user_id, + ) - # Find the trigger node in the workflow - event_node = None - for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): - if node_id == plugin_trigger.node_id: - event_node = node_config - break - - if not event_node: - logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) - continue - - # invoke trigger - trigger_metadata = PluginTriggerMetadata( - plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", - endpoint_id=subscription.endpoint_id, - provider_id=subscription.provider_id, - event_name=event_name, -
icon_filename=trigger_entity.identity.icon or "", - icon_dark_filename=trigger_entity.identity.icon_dark or "", + for plugin_trigger in subscribers: + workflow: Workflow | None = workflows.get(plugin_trigger.app_id) + if not workflow: + logger.error( + "Workflow not found for app %s", + plugin_trigger.app_id, ) + continue - # consume quota before invoking trigger - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) - logger.info( - "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id - ) - return 0 + event_node = None + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): + if node_id == plugin_trigger.node_id: + event_node = node_config + break - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) - invoke_response: TriggerInvokeEventResponse | None = None + if not event_node: + logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) + continue + + trigger_metadata = PluginTriggerMetadata( + plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", + endpoint_id=subscription.endpoint_id, + provider_id=subscription.provider_id, + event_name=event_name, + icon_filename=trigger_entity.identity.icon or "", + icon_dark_filename=trigger_entity.identity.icon_dark or "", + ) + + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) + logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id) + return dispatched_count + + node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) + invoke_response: TriggerInvokeEventResponse | None = None + + with session_factory.create_session() as session: try: invoke_response = TriggerManager.invoke_trigger_event( tenant_id=subscription.tenant_id, @@ -387,6 +387,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", @@ -401,7 +402,7 @@ def dispatch_triggered_workflow( plugin_trigger.app_id, ) - return dispatched_count + return dispatched_count def dispatch_triggered_workflows( diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index b4f975f4da..5ca04fd7c2 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,11 +10,11 @@ import logging from typing import Any from celery import shared_task -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 128cdd72e1..0d5475a56d 100644 --- a/api/tasks/workflow_node_execution_tasks.py 
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -10,13 +10,13 @@ import logging from typing import Any from celery import shared_task +from sqlalchemy import select + +from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from sqlalchemy import select - -from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom
diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..7638652000 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py
@@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData
@@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ + # Ensure expire_on_commit is set to False so the schedule and tenant_owner remain usable after the session closes with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule:
@@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None: if not tenant_owner: raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}") - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) - logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) - return + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) + logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) + return - try: - # Production dispatch: Trigger the workflow normally + try: + with session_factory.create_session() as session: response = AsyncWorkflowService.trigger_workflow_async( session=session, user=tenant_owner,
@@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) - logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) - except Exception as e: - quota_charge.refund() - raise ScheduleExecutionError( - f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" - ) from e + quota_charge.commit() + logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) + except Exception as e: + quota_charge.refund() + raise ScheduleExecutionError( + f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" + ) from e
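Both files above replace the one-shot `QuotaType.TRIGGER.consume(...)` with a two-phase reserve/commit flow that refunds on failure. A minimal sketch of that contract, assuming only what these hunks show (`QuotaService.reserve`, `unlimited()`, and a charge object exposing `commit()`/`refund()`); the wrapper function itself is hypothetical:

```python
from collections.abc import Callable

from enums.quota_type import QuotaType
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited


def charge_trigger_quota(tenant_id: str, dispatch: Callable[[], None]) -> bool:
    """Hypothetical helper illustrating the reserve/commit/refund lifecycle."""
    quota_charge = unlimited()  # no-op charge object used when the plan has no limit
    try:
        quota_charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)
    except QuotaExceededError:
        return False  # caller marks the tenant rate limited and skips dispatch
    try:
        dispatch()  # hand the workflow off (AsyncWorkflowService in the hunks above)
    except Exception:
        quota_charge.refund()  # a failed dispatch must not consume quota
        raise
    quota_charge.commit()  # the reservation becomes a real charge only on success
    return True
```

The point of the split is that the reservation only turns into a real charge once the dispatch has actually been handed off.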
diff --git a/api/templates/without-brand/workflow_comment_mention_template_en-US.html b/api/templates/without-brand/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_en-US.html
@@ -0,0 +1,119 @@ [new 119-line HTML email body; only its text content is shown here: Dify Logo (header image); heading "You were mentioned in a workflow comment"; "Hi {{ mentioned_name }},"; "{{ commenter_name }} mentioned you in {{ app_name }}."; quoted block "{{ comment_content }}"; footer "Open {{ application_title }} to reply to the comment."]
diff --git a/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html
@@ -0,0 +1,119 @@ [new 119-line HTML email body; zh-CN localization of the template above: heading "你在工作流评论中被提及"; "你好,{{ mentioned_name }}:"; "{{ commenter_name }} 在 {{ app_name }} 中提及了你。"; quoted block "{{ comment_content }}"; footer "请在 {{ application_title }} 中查看并回复此评论。"]
diff --git a/api/templates/workflow_comment_mention_template_en-US.html b/api/templates/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/workflow_comment_mention_template_en-US.html
@@ -0,0 +1,119 @@ [same blob (1ef8fe4e3f) as the without-brand en-US template above]
diff --git a/api/templates/workflow_comment_mention_template_zh-CN.html b/api/templates/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/workflow_comment_mention_template_zh-CN.html
@@ -0,0 +1,119 @@ [same blob (8b9b2dbe71) as the without-brand zh-CN template above]
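The four templates follow a `workflow_comment_mention_template_{locale}.html` naming pattern, branded under `api/templates/` and unbranded under `api/templates/without-brand/`. A rough sketch of how a renderer could resolve and fill them; the real lookup lives in `libs.email_i18n` (not part of this diff), so everything here beyond the file layout is assumed:

```python
from jinja2 import Template  # rendering engine assumed for illustration


def render_workflow_comment_mention(language_code: str, context: dict[str, str], branded: bool = True) -> str:
    """Hypothetical resolver mirroring the file layout added above."""
    base = "api/templates" if branded else "api/templates/without-brand"
    path = f"{base}/workflow_comment_mention_template_{language_code}.html"
    with open(path, encoding="utf-8") as f:
        # the Celery task earlier in this diff passes mentioned_name, commenter_name,
        # app_name, comment_content and app_url in template_context
        return Template(f.read()).render(**context)
```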
+ + + diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f84d39aeb5..c07ab6d6bf 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -33,6 +33,7 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false REDIS_DB=0 +REDIS_KEY_PREFIX= # PostgreSQL database configuration DB_USERNAME=postgres diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index b2e8dda443..09078d196d 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -48,7 +48,7 @@ os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" os.environ.setdefault("STORAGE_TYPE", "opendal") os.environ.setdefault("OPENDAL_SCHEME", "fs") -_CACHED_APP = create_app() +_SIO_APP, _CACHED_APP = create_app() @pytest.fixture(scope="session") diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index d10e5ed13c..3b5e822b90 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -171,35 +171,13 @@ class TestChatMessageApiPermissions: parent_message_id=None, ) - class MockQuery: - def __init__(self, model): - self.model = model - - def where(self, *args, **kwargs): - return self - - def first(self): - if getattr(self.model, "__name__", "") == "Conversation": - return mock_conversation - return None - - def order_by(self, *args, **kwargs): - return self - - def limit(self, *_): - return self - - def all(self): - if getattr(self.model, "__name__", "") == "Message": - return [mock_message] - return [] - mock_session = mock.Mock() - mock_session.query.side_effect = MockQuery - mock_session.scalar.return_value = False + mock_session.scalar.return_value = mock_conversation + mock_session.scalars.return_value.all.return_value = [mock_message] monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session)) monkeypatch.setattr(message_api, "current_user", mock_account) + monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock()) class DummyPagination: def __init__(self, data, limit, has_more): diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 038f37af5f..0000000000 --- a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,47 +0,0 @@ -import uuid -from unittest import mock - -from controllers.console.app import workflow_draft_variable as draft_variable_api -from controllers.console.app import wraps -from factories.variable_factory import build_segment -from models import App, AppMode -from models.workflow import WorkflowDraftVariable -from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService - - -def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: - return mock.create_autospec(WorkflowDraftVariableService) - - -class TestWorkflowDraftNodeVariableListApi: - def test_get(self, test_client, auth_header, monkeypatch): - srv_class = _get_mock_srv_class() - mock_app_model: App = App() - mock_app_model.id = str(uuid.uuid4()) - test_node_id = "test_node_id" - mock_app_model.mode = AppMode.ADVANCED_CHAT - 
mock_load_app_model = mock.Mock(return_value=mock_app_model) - - monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) - - var1 = WorkflowDraftVariable.new_node_variable( - app_id="test_app_1", - node_id="test_node_1", - name="str_var", - value=build_segment("str_value"), - node_execution_id=str(uuid.uuid4()), - ) - srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) - srv_class.return_value = srv_instance - srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) - - response = test_client.get( - f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", - headers=auth_header, - ) - assert response.status_code == 200 - response_dict = response.json - assert isinstance(response_dict, dict) - assert "items" in response_dict - assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py deleted file mode 100644 index e55c12e678..0000000000 --- a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Integration tests for Trigger Provider subscription permission verification.""" - -import uuid -from unittest import mock - -import pytest -from flask.testing import FlaskClient - -from controllers.console.workspace import trigger_providers as trigger_providers_api -from libs.datetime_utils import naive_utc_now -from models import Tenant -from models.account import Account, TenantAccountJoin, TenantAccountRole - - -class TestTriggerProviderSubscriptionPermissions: - """Test permission verification for Trigger Provider subscription endpoints.""" - - @pytest.fixture - def mock_account(self, monkeypatch: pytest.MonkeyPatch): - """Create a mock Account for testing.""" - - account = Account(name="Test User", email="test@example.com") - account.id = str(uuid.uuid4()) - account.last_active_at = naive_utc_now() - account.created_at = naive_utc_now() - account.updated_at = naive_utc_now() - - # Create mock tenant - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid.uuid4()) - - mock_session_instance = mock.Mock() - - mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) - monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) - - mock_scalars_result = mock.Mock() - mock_scalars_result.one.return_value = tenant - monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_session_instance - monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) - - account.current_tenant = tenant - account.current_tenant_id = tenant.id - return account - - @pytest.mark.parametrize( - ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), - [ - # Admin/Owner can do everything - (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), - (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), - # Editor can list, get, update (parameters), but not create, build, or delete - (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), - # Normal user cannot do anything - (TenantAccountRole.NORMAL, 403, 403, 
403, 403, 403, 403), - # Dataset operator cannot do anything - (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), - ], - ) - def test_trigger_subscription_permissions( - self, - test_client: FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - list_status: int, - get_status: int, - update_status: int, - create_status: int, - build_status: int, - delete_status: int, - ): - """Test that different roles have appropriate permissions for trigger subscription operations.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - subscription_id = str(uuid.uuid4()) - - # Mock service methods - mock_list_subscriptions = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", - mock_list_subscriptions, - ) - - mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", - mock_get_subscription_builder, - ) - - mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", - mock_update_subscription_builder, - ) - - mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", - mock_create_subscription_builder, - ) - - mock_update_and_build_builder = mock.Mock() - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", - mock_update_and_build_builder, - ) - - mock_delete_provider = mock.Mock() - mock_delete_plugin_trigger = mock.Mock() - mock_db_session = mock.Mock() - mock_db_session.commit = mock.Mock() - - def mock_session_func(engine=None): - return mock_session_context - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_db_session - mock_session_context.__exit__.return_value = None - - monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) - monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) - - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", - mock_delete_provider, - ) - monkeypatch.setattr( - "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", - mock_delete_plugin_trigger, - ) - - # Test 1: List subscriptions (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", - headers=auth_header, - ) - assert response.status_code == list_status 
- - # Test 2: Get subscription builder (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == get_status - - # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", - headers=auth_header, - json={"parameters": {"webhook_url": "https://example.com/webhook"}}, - ) - assert response.status_code == update_status - - # Test 4: Create subscription builder (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", - headers=auth_header, - json={"credential_type": "api_key"}, - ) - assert response.status_code == create_status - - # Test 5: Build/activate subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", - headers=auth_header, - json={"name": "Test Subscription"}, - ) - assert response.status_code == build_status - - # Test 6: Delete subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", - headers=auth_header, - ) - assert response.status_code == delete_status - - @pytest.mark.parametrize( - ("role", "status"), - [ - (TenantAccountRole.OWNER, 200), - (TenantAccountRole.ADMIN, 200), - # Editor should be able to access logs for debugging - (TenantAccountRole.EDITOR, 200), - (TenantAccountRole.NORMAL, 403), - (TenantAccountRole.DATASET_OPERATOR, 403), - ], - ) - def test_trigger_subscription_logs_permissions( - self, - test_client: FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - status: int, - ): - """Test that different roles have appropriate permissions for accessing subscription logs.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - - # Mock service method - mock_list_logs = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", - mock_list_logs, - ) - - # Test access to logs - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == status diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 91245e879e..a876b0c4aa 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,9 +1,8 @@ 
from collections.abc import Generator -from graphon.node_events import StreamCompletedEvent - from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3fdea10976..2392084c36 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: @@ -70,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py deleted file mode 100644 index c1bb8e1245..0000000000 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ /dev/null @@ -1,375 +0,0 @@ -import unittest -from datetime import UTC, datetime -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from graphon.file import File, FileTransferMethod, FileType -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from extensions.ext_database import db -from extensions.storage.storage_type import StorageType -from factories.file_factory import StorageKeyLoader -from models import ToolFile, UploadFile -from models.enums import CreatorUserRole - - -@pytest.mark.usefixtures("flask_req_ctx") -class TestStorageKeyLoader(unittest.TestCase): - """ - Integration tests for StorageKeyLoader class. - - Tests the batched loading of storage keys from the database for files - with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. 
- """ - - def setUp(self): - """Set up test data before each test method.""" - self.session = db.session() - self.tenant_id = str(uuid4()) - self.user_id = str(uuid4()) - self.conversation_id = str(uuid4()) - - # Create test data that will be cleaned up after each test - self.test_upload_files = [] - self.test_tool_files = [] - - # Create StorageKeyLoader instance - self.loader = StorageKeyLoader( - self.session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - - def tearDown(self): - """Clean up test data after each test method.""" - self.session.rollback() - - def _create_upload_file( - self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None - ) -> UploadFile: - """Helper method to create an UploadFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if storage_key is None: - storage_key = f"test_storage_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - upload_file = UploadFile( - tenant_id=tenant_id, - storage_type=StorageType.LOCAL, - key=storage_key, - name="test_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - upload_file.id = file_id - - self.session.add(upload_file) - self.session.flush() - self.test_upload_files.append(upload_file) - - return upload_file - - def _create_tool_file( - self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None - ) -> ToolFile: - """Helper method to create a ToolFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if file_key is None: - file_key = f"test_file_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - tool_file = ToolFile( - user_id=self.user_id, - tenant_id=tenant_id, - conversation_id=self.conversation_id, - file_key=file_key, - mimetype="text/plain", - original_url="http://example.com/file.txt", - name="test_tool_file.txt", - size=2048, - ) - tool_file.id = file_id - self.session.add(tool_file) - self.session.flush() - self.test_tool_files.append(tool_file) - - return tool_file - - def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: - """Helper method to create a File object for testing.""" - if tenant_id is None: - tenant_id = self.tenant_id - - # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods - file_related_id = None - remote_url = None - - if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - file_related_id = related_id - elif transfer_method == FileTransferMethod.REMOTE_URL: - remote_url = "https://example.com/test_file.txt" - file_related_id = related_id - - return File( - id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, - type=FileType.DOCUMENT, - transfer_method=transfer_method, - related_id=file_related_id, - remote_url=remote_url, - filename="test_file.txt", - extension=".txt", - mime_type="text/plain", - size=1024, - storage_key="initial_key", - ) - - def test_load_storage_keys_local_file(self): - """Test loading storage keys for LOCAL_FILE transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == 
upload_file.key - - def test_load_storage_keys_remote_url(self): - """Test loading storage keys for REMOTE_URL transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == upload_file.key - - def test_load_storage_keys_tool_file(self): - """Test loading storage keys for TOOL_FILE transfer method.""" - # Create test data - tool_file = self._create_tool_file() - file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == tool_file.file_key - - def test_load_storage_keys_mixed_methods(self): - """Test batch loading with mixed transfer methods.""" - # Create test data for different transfer methods - upload_file1 = self._create_upload_file() - upload_file2 = self._create_upload_file() - tool_file = self._create_tool_file() - - file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) - file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - files = [file1, file2, file3] - - # Load storage keys - self.loader.load_storage_keys(files) - - # Verify all storage keys were loaded correctly - assert file1._storage_key == upload_file1.key - assert file2._storage_key == upload_file2.key - assert file3._storage_key == tool_file.file_key - - def test_load_storage_keys_empty_list(self): - """Test with empty file list.""" - # Should not raise any exceptions - self.loader.load_storage_keys([]) - - def test_load_storage_keys_ignores_legacy_file_tenant_id(self): - """Legacy file tenant_id should not override the loader tenant scope.""" - upload_file = self._create_upload_file() - file = self._create_file( - related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) - ) - - self.loader.load_storage_keys([file]) - - assert file._storage_key == upload_file.key - - def test_load_storage_keys_missing_file_id(self): - """Test with None file.related_id.""" - # Create a file with valid parameters first, then manually set related_id to None - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = None - - # Should raise ValueError for None file related_id - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) - - assert str(context.value) == "file id should not be None." 
- - def test_load_storage_keys_nonexistent_upload_file_records(self): - """Test with missing UploadFile database records.""" - # Create file with non-existent upload file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_nonexistent_tool_file_records(self): - """Test with missing ToolFile database records.""" - # Create file with non-existent tool file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_invalid_uuid(self): - """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid related_id - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = "invalid-uuid-format" - - # Should raise ValueError for invalid UUID - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_batch_efficiency(self): - """Test batched operations use efficient queries.""" - # Create multiple files of different types - upload_files = [self._create_upload_file() for _ in range(3)] - tool_files = [self._create_tool_file() for _ in range(2)] - - files = [] - files.extend( - [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] - ) - files.extend( - [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] - ) - - # Mock the session to count queries - with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: - self.loader.load_storage_keys(files) - - # Should make exactly 2 queries (one for upload_files, one for tool_files) - assert mock_scalars.call_count == 2 - - # Verify all storage keys were loaded correctly - for i, file in enumerate(files[:3]): - assert file._storage_key == upload_files[i].key - for i, file in enumerate(files[3:]): - assert file._storage_key == tool_files[i].file_key - - def test_load_storage_keys_tenant_isolation(self): - """Test that tenant isolation works correctly.""" - # Create files for different tenants - other_tenant_id = str(uuid4()) - - # Create upload file for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create upload file for other tenant (but don't add to cleanup list) - upload_file_other = UploadFile( - tenant_id=other_tenant_id, - storage_type=StorageType.LOCAL, - key="other_tenant_key", - name="other_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - upload_file_other.id = str(uuid4()) - self.session.add(upload_file_other) - self.session.flush() - - # Create file for other tenant but try to load with current tenant's loader - file_other = self._create_file( - related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError due to tenant mismatch - with 
pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_other]) - - assert "Upload file not found for id:" in str(context.value) - - # Current tenant's file should still work - self.loader.load_storage_keys([file_current]) - assert file_current._storage_key == upload_file_current.key - - def test_load_storage_keys_mixed_tenant_batch(self): - """Test batch with mixed tenant files (should fail on first mismatch).""" - # Create files for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create file for different tenant - other_tenant_id = str(uuid4()) - file_other = self._create_file( - related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError on tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_current, file_other]) - - assert "Upload file not found for id:" in str(context.value) - - def test_load_storage_keys_duplicate_file_ids(self): - """Test handling of duplicate file IDs in the batch.""" - # Create upload file - upload_file = self._create_upload_file() - - # Create two File objects with same related_id - file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should handle duplicates gracefully - self.loader.load_storage_keys([file1, file2]) - - # Both files should have the same storage key - assert file1._storage_key == upload_file.key - assert file2._storage_key == upload_file.key - - def test_load_storage_keys_session_isolation(self): - """Test that the loader uses the provided session correctly.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Create loader with different session (same underlying connection) - - with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader( - other_session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - with pytest.raises(ValueError): - other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index ce04a158a8..c4146d5ccd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,6 +4,9 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -23,9 +26,6 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git 
a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index c7bb90f019..e130644338 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,10 +3,6 @@ import unittest import uuid import pytest -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable from sqlalchemy import delete, func, select from sqlalchemy.orm import Session @@ -15,6 +11,10 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 3dfedd811d..4f444598b1 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest -from graphon.variables.segments import StringSegment from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -209,7 +209,6 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -453,7 +452,6 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index c0143faa85..a9a2617bae 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,12 +1,11 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import ModelType 
from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py deleted file mode 100644 index 487178ff58..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - -CODE_LANGUAGE = "unsupported_language" - - -def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionError) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) - assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py deleted file mode 100644 index c8eb9ec3e4..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ /dev/null @@ -1,95 +0,0 @@ -import base64 - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.JINJA2 - - -def test_jinja2(): - """Test basic Jinja2 template rendering.""" - template = "Hello {{template}}" - # Template must be base64 encoded to match the new safe embedding approach - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") - code = ( - Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) - ) - result = CodeExecutor.execute_code( - language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code - ) - assert result == "<>Hello World<>\n" - - -def test_jinja2_with_code_template(): - """Test template rendering via the high-level workflow API.""" - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} - ) - assert result == {"result": "Hello World"} - - -def test_jinja2_get_runner_script(): - """Test that runner script contains required placeholders.""" - runner_script = Jinja2TemplateTransformer.get_runner_script() - assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1 - assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 - - -def test_jinja2_template_with_special_characters(): - """ - Test that templates with special characters (quotes, newlines) render correctly. - This is a regression test for issue #26818 where textarea pre-fill values - containing special characters would break template rendering. - """ - # Template with triple quotes, single quotes, double quotes, and newlines - template = """ - - - -

Status: "{{ status }}"

-
'''code block'''
- -""" - inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - # Verify the template rendered correctly with all special characters - output = result["result"] - assert 'value="TASK-123"' in output - assert "" in output - assert 'Status: "completed"' in output - assert "'''code block'''" in output - - -def test_jinja2_template_with_html_textarea_prefill(): - """ - Specific test for HTML textarea with Jinja2 variable pre-fill. - Verifies fix for issue #26818. - """ - template = "" - notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes." - inputs = {"notes": notes_content} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - expected_output = f"" - assert result["result"] == expected_output - - -def test_jinja2_assemble_runner_script_encodes_template(): - """Test that assemble_runner_script properly base64 encodes the template.""" - template = "Hello {{ name }}!" - inputs = {"name": "World"} - - script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs) - - # The template should be base64 encoded in the script - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - assert template_b64 in script - # The raw template should NOT appear in the script (it's encoded) - assert "Hello {{ name }}!" not in script diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py deleted file mode 100644 index 25af312afa..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ /dev/null @@ -1,36 +0,0 @@ -from textwrap import dedent - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.PYTHON3 - - -def test_python3_plain(): - code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == "Hello World\n" - - -def test_python3_json(): - code = dedent(""" - import json - print(json.dumps({'Hello': 'World'})) - """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == '{"Hello": "World"}\n' - - -def test_python3_with_code_template(): - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} - ) - assert result == {"result": "HelloWorld"} - - -def test_python3_get_runner_script(): - runner_script = Python3TemplateTransformer.get_runner_script() - assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22..aaa6092993 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,18 @@ import time import uuid 
import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import NodeRunResult -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e738..b9f7b9575b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In 
+    from core.workflow.system_variables import build_system_variables
     from graphon.enums import BuiltinNodeTypes
     from graphon.nodes.http_request.entities import (
         HttpRequestNodeAuthorization,
@@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
     from graphon.nodes.http_request.executor import Executor
     from graphon.runtime import VariablePool
 
-    from core.workflow.system_variables import build_system_variables
-
     # Create variable pool
     variable_pool = VariablePool(
         system_variables=build_system_variables(user_id="test", files=[]),
@@ -724,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
 
     node = HttpRequestNode(
-        id=str(uuid.uuid4()),
-        config=graph_config["nodes"][1],
+        node_id=str(uuid.uuid4()),
+        config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index f0f3fcead1..3eead70163 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -4,20 +4,20 @@ import uuid
 from collections.abc import Generator
 from unittest.mock import MagicMock, patch
 
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.llm_generator.output_parser.structured_output import _parse_structured_output
+from core.model_manager import ModelInstance
+from core.workflow.system_variables import build_system_variables
+from extensions.ext_database import db
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.llm.entities import LLMNodeData
 from graphon.nodes.llm.file_saver import LLMFileSaver
 from graphon.nodes.llm.node import LLMNode
 from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
 from graphon.nodes.protocols import HttpClientProtocol
 from graphon.runtime import GraphRuntimeState, VariablePool
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.llm_generator.output_parser.structured_output import _parse_structured_output
-from core.model_manager import ModelInstance
-from core.workflow.system_variables import build_system_variables
-from extensions.ext_database import db
 from tests.workflow_test_utils import build_test_graph_init_params
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
@@ -76,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
     llm_file_saver = MagicMock(spec=LLMFileSaver)
 
     node = LLMNode(
-        id=str(uuid.uuid4()),
-        config=config,
+        node_id=str(uuid.uuid4()),
+        config=LLMNodeData.model_validate(config["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index fe512c2585..f2eabb86c3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -3,17 +3,17 @@ import time
 import uuid
 from unittest.mock import MagicMock
 
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
-from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
-from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.model_manager import ModelInstance
 from core.workflow.node_runtime import DifyPromptMessageSerializer
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
+from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
 from tests.workflow_test_utils import build_test_graph_init_params
 
@@ -70,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
 
     node = ParameterExtractorNode(
-        id=str(uuid.uuid4()),
-        config=config,
+        node_id=str(uuid.uuid4()),
+        config=ParameterExtractorNodeData.model_validate(config["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 2d728569be..e2e0723fb8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -1,15 +1,15 @@
 import time
 import uuid
 
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.graph import Graph
-from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.template_rendering import TemplateRenderError
-
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.system_variables import build_system_variables
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.graph import Graph
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
+from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.template_rendering import TemplateRenderError
 from tests.workflow_test_utils import build_test_graph_init_params
 
@@ -87,8 +87,8 @@ def test_execute_template_transform():
     assert graph is not None
 
     node = TemplateTransformNode(
-        id=str(uuid.uuid4()),
-        config=config,
+        node_id=str(uuid.uuid4()),
+        config=TemplateTransformNodeData.model_validate(config["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         jinja2_template_renderer=_SimpleJinja2Renderer(),
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index 750ced7075..a8e9422c1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -2,18 +2,18 @@ import time
 import uuid
 from unittest.mock import MagicMock, patch
 
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.graph import Graph
-from graphon.node_events import StreamCompletedEvent
-from graphon.nodes.protocols import ToolFileManagerProtocol
-from graphon.nodes.tool.tool_node import ToolNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.node_runtime import DifyToolNodeRuntime
 from core.workflow.system_variables import build_system_variables
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.graph import Graph
+from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.protocols import ToolFileManagerProtocol
+from graphon.nodes.tool.entities import ToolNodeData
+from graphon.nodes.tool.tool_node import ToolNode
+from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params
 
@@ -61,8 +61,8 @@ def init_tool_node(config: dict):
     tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
 
     node = ToolNode(
-        id=str(uuid.uuid4()),
-        config=config,
+        node_id=str(uuid.uuid4()),
+        config=ToolNodeData.model_validate(config["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         tool_file_manager_factory=tool_file_manager_factory,
diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py
index ef74893f07..66a25e5daf 100644
--- a/api/tests/test_containers_integration_tests/conftest.py
+++ b/api/tests/test_containers_integration_tests/conftest.py
@@ -369,7 +369,7 @@ def _create_app_with_containers() -> Flask:
 
     # Create and configure the Flask application
     logger.info("Initializing Flask application...")
-    app = create_app()
+    sio_app, app = create_app()
    logger.info("Flask application created successfully")
 
    # Initialize database schema
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
index 54e0496dbd..18755ef012 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
@@ -234,6 +234,35 @@ class TestAppEndpoints:
             }
         )
 
+    def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
+        api = app_module.AppIconApi()
+        method = _unwrap(api.post)
+        payload = {
+            "icon": "https://example.com/icon.png",
+            "icon_type": "image",
+            "icon_background": "#FFFFFF",
+        }
+        app_service = MagicMock()
+        app_service.update_app_icon.return_value = SimpleNamespace()
+        response_model = MagicMock()
+        response_model.model_dump.return_value = {"id": "app-1"}
+
+        monkeypatch.setattr(app_module, "AppService", lambda: app_service)
+        monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model))
+
+        with (
+            app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload),
+            patch.object(type(console_ns), "payload", payload),
+        ):
+            response = method(app_model=SimpleNamespace())
+
+        assert response == {"id": "app-1"}
+        assert app_service.update_app_icon.call_args.args[1:] == (
+            payload["icon"],
+            payload["icon_background"],
+            app_module.IconType.IMAGE,
+        )
+
 
 class TestOpsTraceEndpoints:
     @pytest.fixture
@@ -432,7 +461,7 @@ class TestWorkflowAppLogEndpoints:
         monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker)
 
         def fake_get_paginate(self, **_kwargs):
-            return {"items": [], "total": 0}
+            return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
 
         monkeypatch.setattr(
             workflow_app_log_module.WorkflowAppService,
@@ -443,7 +472,7 @@ class TestWorkflowAppLogEndpoints:
         with app.test_request_context("/?page=1&limit=20"):
             result = method(app_model=SimpleNamespace(id="app-1"))
 
-        assert result == {"items": [], "total": 0}
+        assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
 
 
 class TestWorkflowDraftVariableEndpoints:
@@ -608,7 +637,8 @@ class TestWorkflowTriggerEndpoints:
         with app.test_request_context("/?node_id=node-1"):
             result = method(app_model=SimpleNamespace(id="app-1"))
 
-        assert result is trigger
+        assert isinstance(result, dict)
+        assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())
 
 
 class TestWrapsEndpoints:
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py
index d8c6821f8d..25d19cf35a 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py
@@ -96,6 +96,56 @@ class TestAppImportApi:
         assert status == 200
         assert response["status"] == ImportStatus.COMPLETED
 
+    def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        api = app_import_module.AppImportApi()
+        method = _unwrap(api.post)
+
+        _install_features(monkeypatch, enabled=False)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "import_app",
+            lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
+        )
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        fake_session = MagicMock()
+        fake_session.__enter__.return_value = fake_session
+        fake_session.__exit__.return_value = None
+        monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
+
+        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
+            response, status = method()
+
+        fake_session.commit.assert_called_once_with()
+        fake_session.rollback.assert_not_called()
+        assert status == 200
+        assert response["status"] == ImportStatus.COMPLETED
+
+    def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        api = app_import_module.AppImportApi()
+        method = _unwrap(api.post)
+
+        _install_features(monkeypatch, enabled=False)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "import_app",
+            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
+        )
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        fake_session = MagicMock()
+        fake_session.__enter__.return_value = fake_session
+        fake_session.__exit__.return_value = None
+        monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
+
+        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
+            response, status = method()
+
+        fake_session.rollback.assert_called_once_with()
+        fake_session.commit.assert_not_called()
+        assert status == 400
+        assert response["status"] == ImportStatus.FAILED
+
 
 class TestAppImportConfirmApi:
     @pytest.fixture
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
index 5cc458fe2e..5a22f81a69 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py
@@ -4,15 +4,15 @@ import json
 import uuid
 
 from flask.testing import FlaskClient
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session
 
 from configs import dify_config
 from constants import HEADER_NAME_CSRF_TOKEN
+from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from libs.token import _real_cookie_name, generate_csrf_token
 from models import Account, DifySetup, Tenant, TenantAccountJoin
-from models.account import AccountStatus, TenantAccountRole
+from models.account import AccountStatus, TenantAccountRole, TenantStatus
 from models.enums import ConversationFromSource, CreatorUserRole
 from models.model import App, AppMode, Conversation, Message
 from models.workflow import WorkflowRun
@@ -30,7 +30,7 @@ def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
     db_session.add(account)
     db_session.commit()
 
-    tenant = Tenant(name="Test Tenant", status="normal")
+    tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL)
     db_session.add(tenant)
     db_session.commit()
 
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py
new file mode 100644
index 0000000000..fad0b8b10e
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py
@@ -0,0 +1,73 @@
+from datetime import datetime
+from unittest.mock import patch
+
+import pytest
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import NotFound
+
+from controllers.console.app.conversation import _get_conversation
+from models.enums import ConversationFromSource
+from models.model import AppMode, Conversation
+from tests.test_containers_integration_tests.controllers.console.helpers import (
+    create_console_account_and_tenant,
+    create_console_app,
+)
+
+
+def test_get_conversation_mark_read_keeps_updated_at_unchanged(
+    db_session_with_containers: Session,
+):
+    account, tenant = create_console_account_and_tenant(db_session_with_containers)
+    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
+
+    original_updated_at = datetime(2026, 2, 8, 0, 0, 0)
+    conversation = Conversation(
+        app_id=app.id,
+        name="read timestamp test",
+        inputs={},
+        status="normal",
+        mode=AppMode.CHAT,
+        from_source=ConversationFromSource.CONSOLE,
+        from_account_id=account.id,
+        updated_at=original_updated_at,
+    )
+    db_session_with_containers.add(conversation)
+    db_session_with_containers.commit()
+
+    read_at = datetime(2026, 2, 9, 0, 0, 0)
+
+    with (
+        patch(
+            "controllers.console.app.conversation.current_account_with_tenant",
+            return_value=(account, tenant.id),
+            autospec=True,
+        ),
+        patch(
+            "controllers.console.app.conversation.naive_utc_now",
+            return_value=read_at,
+            autospec=True,
+        ),
+    ):
+        loaded = _get_conversation(app, conversation.id)
+
+    db_session_with_containers.refresh(conversation)
+
+    assert loaded.id == conversation.id
+    assert conversation.read_at == read_at
+    assert conversation.read_account_id == account.id
+    assert conversation.updated_at == original_updated_at
+
+
+def test_get_conversation_raises_not_found_for_missing_conversation(
+    db_session_with_containers: Session,
+):
+    account, tenant = create_console_account_and_tenant(db_session_with_containers)
+    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
+
+    with patch(
+        "controllers.console.app.conversation.current_account_with_tenant",
+        return_value=(account, tenant.id),
+        autospec=True,
+    ):
+        with pytest.raises(NotFound):
+            _get_conversation(app, "00000000-0000-0000-0000-000000000000")
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
index 8ddf867370..290be87697 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py
@@ -3,12 +3,12 @@ import uuid
 
 from flask.testing import FlaskClient
-from graphon.variables.segments import StringSegment
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from factories.variable_factory import segment_to_variable
+from graphon.variables.segments import StringSegment
 from models import Workflow
 from models.model import AppMode
 from models.workflow import WorkflowDraftVariable
diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py
index 9e2084f393..a8ecf94da1 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/helpers.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py
@@ -11,7 +11,7 @@ from constants import HEADER_NAME_CSRF_TOKEN
 from libs.datetime_utils import naive_utc_now
 from libs.token import _real_cookie_name, generate_csrf_token
 from models import Account, DifySetup, Tenant, TenantAccountJoin
-from models.account import AccountStatus, TenantAccountRole
+from models.account import AccountStatus, TenantAccountRole, TenantStatus
 from models.model import App, AppMode
 from services.account_service import AccountService
 
@@ -37,7 +37,7 @@ def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Ten
     db_session.add(account)
     db_session.commit()
 
-    tenant = Tenant(name="Test Tenant", status="normal")
+    tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL)
     db_session.add(tenant)
     db_session.commit()
 
diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py
new file mode 100644
index 0000000000..4e884626a7
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py
@@ -0,0 +1,110 @@
+"""
+Testcontainers integration tests for Service API Site controller.
+"""
+
+from __future__ import annotations
+
+import pytest
+from flask import Flask
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.service_api.app.site import AppSiteApi
+from models.account import Tenant, TenantStatus
+from models.model import App, AppMode, Site
+
+
+@pytest.fixture
+def app(flask_app_with_containers) -> Flask:
+    return flask_app_with_containers
+
+
+def _unwrap(method):
+    fn = method
+    while hasattr(fn, "__wrapped__"):
+        fn = fn.__wrapped__
+    return fn
+
+
+def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
+    tenant = Tenant(name="service-api-site-tenant", status=status)
+    db_session.add(tenant)
+    db_session.commit()
+    return tenant
+
+
+def _create_app(db_session: Session, tenant_id: str) -> App:
+    app_model = App(
+        tenant_id=tenant_id,
+        mode=AppMode.CHAT,
+        name="service-api-site-app",
+        enable_site=True,
+        enable_api=True,
+        status="normal",
+    )
+    db_session.add(app_model)
+    db_session.commit()
+    return app_model
+
+
+def _create_site(db_session: Session, app_id: str) -> Site:
+    site = Site(
+        app_id=app_id,
+        title="Service API Site",
+        icon_type="emoji",
+        icon="robot",
+        icon_background="#ffffff",
+        description="Service API test site",
+        default_language="en-US",
+        prompt_public=True,
+        show_workflow_steps=True,
+        customize_token_strategy="not_allow",
+        use_icon_as_answer_icon=False,
+        chat_color_theme="light",
+        chat_color_theme_inverted=False,
+    )
+    db_session.add(site)
+    db_session.commit()
+    return site
+
+
+class TestAppSiteApi:
+    def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            response = _unwrap(api.get)(api, app_model=app_model)
+
+        assert response["title"] == "Service API Site"
+        assert response["icon"] == "robot"
+        assert response["description"] == "Service API test site"
+
+    def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            with pytest.raises(Forbidden):
+                _unwrap(api.get)(api, app_model=app_model)
+
+    def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
+
+        archived_tenant = db_session_with_containers.get(Tenant, tenant.id)
+        assert archived_tenant is not None
+        archived_tenant.status = TenantStatus.ARCHIVE
+        db_session_with_containers.commit()
+
+        app_model = db_session_with_containers.get(App, app_model.id)
+        assert app_model is not None
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            with pytest.raises(Forbidden):
+                _unwrap(api.get)(api, app_model=app_model)
diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py
similarity index 51%
rename from api/tests/unit_tests/controllers/web/test_site.py
rename to api/tests/test_containers_integration_tests/controllers/web/test_site.py
index 6e9d754c43..9adb26ff3d 100644
--- a/api/tests/unit_tests/controllers/web/test_site.py
+++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py
@@ -1,28 +1,48 @@
-"""Unit tests for controllers.web.site endpoints."""
+"""Testcontainers integration tests for controllers.web.site endpoints."""
 
 from __future__ import annotations
 
 from types import SimpleNamespace
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 
 import pytest
 from flask import Flask
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 from controllers.web.site import AppSiteApi, AppSiteInfo
+from models import Tenant, TenantStatus
+from models.model import App, AppMode, CustomizeTokenStrategy, Site
 
 
-def _tenant(*, status: str = "normal") -> SimpleNamespace:
-    return SimpleNamespace(
-        id="tenant-1",
-        status=status,
-        plan="basic",
-        custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
+@pytest.fixture
+def app(flask_app_with_containers) -> Flask:
+    return flask_app_with_containers
+
+
+def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
+    tenant = Tenant(name="test-tenant", status=status)
+    db_session.add(tenant)
+    db_session.commit()
+    return tenant
+
+
+def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App:
+    app_model = App(
+        tenant_id=tenant_id,
+        mode=AppMode.CHAT,
+        name="test-app",
+        enable_site=enable_site,
+        enable_api=True,
     )
+    db_session.add(app_model)
+    db_session.commit()
+    return app_model
 
 
-def _site() -> SimpleNamespace:
-    return SimpleNamespace(
+def _create_site(db_session: Session, app_id: str) -> Site:
+    site = Site(
+        app_id=app_id,
         title="Site",
         icon_type="emoji",
         icon="robot",
@@ -31,77 +51,64 @@
         default_language="en",
         chat_color_theme="light",
         chat_color_theme_inverted=False,
-        copyright=None,
-        privacy_policy=None,
-        custom_disclaimer=None,
+        customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
+        code=f"code-{app_id[-6:]}",
         prompt_public=False,
         show_workflow_steps=True,
         use_icon_as_answer_icon=False,
     )
+    db_session.add(site)
+    db_session.commit()
+    return site
 
 
-# ---------------------------------------------------------------------------
-# AppSiteApi
-# ---------------------------------------------------------------------------
 class TestAppSiteApi:
     @patch("controllers.web.site.FeatureService.get_features")
-    @patch("controllers.web.site.db")
-    def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
+    def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
-        site_obj = _site()
-        mock_db.session.scalar.return_value = site_obj
-        tenant = _tenant()
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
         end_user = SimpleNamespace(id="eu-1")
+        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
 
         with app.test_request_context("/site"):
             result = AppSiteApi().get(app_model, end_user)
 
-        # marshal_with serializes AppSiteInfo to a dict
-        assert result["app_id"] == "app-1"
+        assert result["app_id"] == app_model.id
         assert result["plan"] == "basic"
         assert result["enable_site"] is True
 
-    @patch("controllers.web.site.db")
-    def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
+    def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        mock_db.session.scalar.return_value = None
-        tenant = _tenant()
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
        end_user = SimpleNamespace(id="eu-1")
 
         with app.test_request_context("/site"):
             with pytest.raises(Forbidden):
                 AppSiteApi().get(app_model, end_user)
 
-    @patch("controllers.web.site.db")
-    def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
+    @patch("controllers.web.site.FeatureService.get_features")
+    def test_archived_tenant_raises_forbidden(
+        self, mock_features, app: Flask, db_session_with_containers: Session
+    ) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        from models.account import TenantStatus
-
-        mock_db.session.scalar.return_value = _site()
-        tenant = SimpleNamespace(
-            id="tenant-1",
-            status=TenantStatus.ARCHIVE,
-            plan="basic",
-            custom_config_dict={},
-        )
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
+        tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
         end_user = SimpleNamespace(id="eu-1")
+        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
 
         with app.test_request_context("/site"):
             with pytest.raises(Forbidden):
                 AppSiteApi().get(app_model, end_user)
 
 
-# ---------------------------------------------------------------------------
-# AppSiteInfo
-# ---------------------------------------------------------------------------
 class TestAppSiteInfo:
     def test_basic_fields(self) -> None:
-        tenant = _tenant()
-        site_obj = _site()
+        tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
+        site_obj = SimpleNamespace()
 
         info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
 
         assert info.app_id == "app-1"
@@ -118,7 +125,7 @@ class TestAppSiteInfo:
             plan="pro",
             custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
         )
-        site_obj = _site()
+        site_obj = SimpleNamespace()
 
         info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
 
         assert info.can_replace_logo is True
diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
index f14b2c0ae5..635cfee2da 100644
--- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
+++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
@@ -24,7 +24,6 @@ def _patch_wraps():
         patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index c9ee67863d..c342e8994b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,13 +22,6 @@ import uuid from time import time import pytest -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -40,6 +33,13 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -88,11 +88,11 @@ class TestPauseStatePersistenceLayerTestContainers: def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account - from models.account import Tenant, TenantAccountJoin, TenantAccountRole + from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestPauseStatePersistenceLayerTestContainers: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 13caad799e..6524d6ce61 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,12 +4,11 @@ from __future__ import annotations from uuid import uuid4 -from 
-from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session
 
 from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     DeliveryChannelConfig,
     EmailDeliveryConfig,
     EmailDeliveryMethod,
@@ -18,7 +17,15 @@ from core.workflow.human_input_compat import (
     MemberRecipient,
     WebAppDeliveryMethod,
 )
-from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
+from models.account import (
+    Account,
+    AccountStatus,
+    Tenant,
+    TenantAccountJoin,
+    TenantAccountRole,
+    TenantStatus,
+)
 from models.human_input import (
     EmailExternalRecipientPayload,
     EmailMemberRecipientPayload,
@@ -29,7 +36,7 @@
 
 def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
-    tenant = Tenant(name="Test Tenant", status="normal")
+    tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL)
     session.add(tenant)
     session.flush()
 
@@ -39,7 +46,7 @@ def _create_tenant_with_members(session: Session, member_emails: list[str]) -> t
             email=email,
             name=f"Member {index}",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         session.add(account)
         session.flush()
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index 0a9b476afc..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -4,6 +4,17 @@ from datetime import timedelta
 from unittest.mock import MagicMock
 
 import pytest
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from core.app.app_config.entities import WorkflowUIBasedAppConfig
+from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository
+from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowType
 from graphon.graph import Graph
 from graphon.graph_engine import GraphEngine
@@ -16,20 +27,9 @@ from graphon.graph_engine import GraphEngine
 from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.nodes.start.entities import StartNodeData
 from graphon.nodes.start.start_node import StartNode
 from graphon.runtime import GraphRuntimeState, VariablePool
-from sqlalchemy import delete, select
-from sqlalchemy.orm import Session
-
-from core.app.app_config.entities import WorkflowUIBasedAppConfig
-from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
-from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
-from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
-from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
 from libs.datetime_utils import naive_utc_now
 from models import Account
-from models.account import Tenant, TenantAccountJoin, TenantAccountRole
+from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.model import App, AppMode, IconType
 from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
@@ -101,8 +101,8 @@ def _build_graph(
     start_data = StartNodeData(title="start", variables=[])
     start_node = StartNode(
-        id="start",
-        config={"id": "start", "data": start_data.model_dump()},
+        node_id="start",
+        config=start_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
     )
@@ -116,8 +116,8 @@
         ],
     )
     human_node = HumanInputNode(
-        id="human",
-        config={"id": "human", "data": human_data.model_dump()},
+        node_id="human",
+        config=human_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
         form_repository=form_repository,
@@ -130,8 +130,8 @@
         desc=None,
     )
     end_node = EndNode(
-        id="end",
-        config={"id": "end", "data": end_data.model_dump()},
+        node_id="end",
+        config=end_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
     )
@@ -175,7 +175,7 @@ class TestHumanInputResumeNodeExecutionIntegration:
     def setup_test_data(self, db_session_with_containers: Session):
         tenant = Tenant(
             name="Test Tenant",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -184,7 +184,7 @@
         account = Account(
             email="test@example.com",
             name="Test User",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index cc72dc1cf3..35e41035df 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -4,13 +4,13 @@ from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
-from graphon.file import File, FileTransferMethod, FileType
 from sqlalchemy.orm import Session
 
 from core.app.file_access import DatabaseFileAccessController
 from extensions.ext_database import db
 from extensions.storage.storage_type import StorageType
 from factories.file_factory import StorageKeyLoader
+from graphon.file import File, FileTransferMethod, FileType
 from models import ToolFile, UploadFile
 from models.enums import CreatorUserRole
 
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
             file_related_id = related_id
 
         return File(
-            id=str(uuid4()),  # Generate new UUID for File.id
+            file_id=str(uuid4()),  # Generate new UUID for File.id
             tenant_id=tenant_id,
-            type=FileType.DOCUMENT,
+            file_type=FileType.DOCUMENT,
             transfer_method=transfer_method,
             related_id=file_related_id,
             remote_url=remote_url,
diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
index b745aed141..2fd289dfbc 100644
--- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
+++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
@@ -6,7 +6,6 @@ from decimal import Decimal
 from uuid import uuid4
 
 from graphon.nodes.human_input.entities import FormDefinition, UserAction
-
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant, TenantAccountJoin
 from models.enums import ConversationFromSource, InvokeFrom
diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py
index e922c19a5a..f10f519e25 100644
--- a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py
+++ b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py
@@ -10,10 +10,10 @@ from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
-from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
 from sqlalchemy.orm import Session
 
 from core.workflow.file_reference import build_file_reference
+from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
 from models.model import App, AppMode, Conversation, Message
diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py
index 4ca87de52d..6352f815df 100644
--- a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py
+++ b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py
@@ -9,9 +9,9 @@ from collections.abc import Generator
 from uuid import uuid4
 
 import pytest
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session
 
+from graphon.enums import WorkflowExecutionStatus
 from models.enums import ConversationFromSource, InvokeFrom
 from models.model import App, AppMode, Conversation, Message, Site
 from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType
diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py
index 957b7145d3..b325c97f7d 100644
--- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py
+++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py
@@ -4,13 +4,13 @@ from typing import Any, NamedTuple
 
 import pytest
 import sqlalchemy as sa
-from graphon.model_runtime.entities.model_entities import ModelType
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import insert, select
 from sqlalchemy.engine import Connection, Engine
 from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
 from sqlalchemy.sql.sqltypes import VARCHAR
 
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.types import EnumText
 
 _USER_TABLE = "enum_text_users"
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
index a68b3a08c7..641399c7f9 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
@@ -5,10 +5,10 @@ from __future__ import annotations
 from datetime import timedelta
 from uuid import uuid4
 
-from graphon.enums import WorkflowNodeExecutionStatus
 from sqlalchemy import Engine, delete
 from sqlalchemy.orm import Session, sessionmaker
 
+from graphon.enums import WorkflowNodeExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole
 from models.workflow import WorkflowNodeExecutionModel
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
index 64c93ac07c..d9828e19c5 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -2,27 +2,31 @@
 
 from __future__ import annotations
 
+import secrets
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from unittest.mock import Mock
 from uuid import uuid4
 
 import pytest
+from sqlalchemy import Engine, delete, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.workflow.human_input_adapter import DeliveryMethodType
+from extensions.ext_storage import storage
 from graphon.entities import WorkflowExecution
 from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
 from graphon.enums import WorkflowExecutionStatus
 from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction
 from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
-from sqlalchemy import Engine, delete, select
-from sqlalchemy.orm import Session, sessionmaker
-
-from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.human_input import (
+    BackstageRecipientPayload,
     HumanInputDelivery,
     HumanInputForm,
     HumanInputFormRecipient,
+    RecipientType,
 )
 from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
 from repositories.entities.workflow_pause import WorkflowPauseEntity
@@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
 class TestBuildHumanInputRequiredReason:
     """Integration tests for _build_human_input_required_reason using real DB models."""
 
-    def test_builds_reason_from_form_definition(
+    def test_prefers_standalone_web_app_token_when_available(
         self,
         db_session_with_containers: Session,
         test_scope: _TestScope,
     ) -> None:
-        """Build the graph pause reason from the stored form definition."""
+        """Use the public standalone web-app token for service API payloads."""
         expiration_time = naive_utc_now()
 
         form_definition = FormDefinition(
@@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
         db_session_with_containers.add(form_model)
         db_session_with_containers.flush()
 
+        delivery = HumanInputDelivery(
+            form_id=form_model.id,
+            delivery_method_type=DeliveryMethodType.WEBAPP,
+            channel_payload="{}",
+        )
+        db_session_with_containers.add(delivery)
+        db_session_with_containers.flush()
+
+        backstage_access_token = secrets.token_urlsafe(8)
+        backstage_recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.BACKSTAGE,
+            recipient_payload=BackstageRecipientPayload().model_dump_json(),
+            access_token=backstage_access_token,
+        )
+        console_access_token = secrets.token_urlsafe(8)
+        console_recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.CONSOLE,
+            recipient_payload="{}",
+            access_token=console_access_token,
+        )
+        web_app_access_token = secrets.token_urlsafe(8)
+        web_app_recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.STANDALONE_WEB_APP,
+            recipient_payload="{}",
+            access_token=web_app_access_token,
+        )
+        db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
+        db_session_with_containers.flush()
 
         # Create a pause so the reason has a valid pause_id
         workflow_run = _create_workflow_run(
             db_session_with_containers,
@@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
         # Refresh to ensure we have DB-round-tripped objects
         db_session_with_containers.refresh(form_model)
         db_session_with_containers.refresh(reason_model)
+        db_session_with_containers.refresh(backstage_recipient)
+        db_session_with_containers.refresh(console_recipient)
+        db_session_with_containers.refresh(web_app_recipient)
 
-        reason = _build_human_input_required_reason(reason_model, form_model)
+        reason = _build_human_input_required_reason(
+            reason_model,
+            form_model,
+            [backstage_recipient, console_recipient, web_app_recipient],
+        )
 
         assert isinstance(reason, HumanInputRequired)
         assert reason.node_title == "Ask Name"
@@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
         assert reason.inputs[0].output_variable_name == "name"
         assert reason.actions[0].id == "approve"
         assert reason.resolved_default_values == {"name": "Alice"}
+        assert not hasattr(reason, "form_token")
+
+    def test_falls_back_to_console_token_when_web_app_token_missing(
+        self,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Use the console token only when no standalone web-app token exists."""
+
+        expiration_time = naive_utc_now()
+        form_definition = FormDefinition(
+            form_content="content",
+            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
+            user_actions=[UserAction(id="approve", title="Approve")],
+            rendered_content="rendered",
+            expiration_time=expiration_time,
+            default_values={"name": "Alice"},
+            node_title="Ask Name",
+            display_in_ui=True,
+        )
+
+        form_model = HumanInputForm(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            workflow_run_id=str(uuid4()),
+            node_id="node-1",
+            form_definition=form_definition.model_dump_json(),
+            rendered_content="rendered",
+            status=HumanInputFormStatus.WAITING,
+            expiration_time=expiration_time,
+        )
+        db_session_with_containers.add(form_model)
+        db_session_with_containers.flush()
+
+        delivery = HumanInputDelivery(
+            form_id=form_model.id,
+            delivery_method_type=DeliveryMethodType.WEBAPP,
+            channel_payload="{}",
+        )
+        db_session_with_containers.add(delivery)
+        db_session_with_containers.flush()
+
+        backstage_access_token = secrets.token_urlsafe(8)
+        backstage_recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.BACKSTAGE,
+            recipient_payload=BackstageRecipientPayload().model_dump_json(),
+            access_token=backstage_access_token,
+        )
+        console_access_token = secrets.token_urlsafe(8)
+        console_recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.CONSOLE,
+            recipient_payload="{}",
+            access_token=console_access_token,
+        )
+        db_session_with_containers.add_all([backstage_recipient, console_recipient])
+        db_session_with_containers.flush()
+
+        workflow_run = _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.RUNNING,
+        )
+        pause = WorkflowPause(
+            workflow_id=test_scope.workflow_id,
+            workflow_run_id=workflow_run.id,
+            state_object_key=f"workflow-state-{uuid4()}.json",
+        )
+        db_session_with_containers.add(pause)
+        db_session_with_containers.flush()
+        test_scope.state_keys.add(pause.state_object_key)
+
+        reason_model = WorkflowPauseReason(
+            pause_id=pause.id,
+            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
+            form_id=form_model.id,
+            node_id="node-1",
+            message="",
+        )
+        db_session_with_containers.add(reason_model)
+        db_session_with_containers.commit()
+
+        reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
+
+        assert isinstance(reason, HumanInputRequired)
+        assert not hasattr(reason, "form_token")
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index 7f44eb6ca3..54b7afc018 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -12,11 +12,11 @@ from decimal import Decimal
 from uuid import uuid4
 
 import pytest
-from graphon.nodes.human_input.entities import FormDefinition, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormStatus
 from sqlalchemy import Engine, delete, select
 from sqlalchemy.orm import Session, sessionmaker
 
+from graphon.nodes.human_input.entities import FormDefinition, UserAction
+from graphon.nodes.human_input.enums import HumanInputFormStatus
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import ConversationFromSource, InvokeFrom
@@ -271,7 +271,7 @@ def _create_recipient(
 
 
 def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
-    from core.workflow.human_input_compat import DeliveryMethodType
+    from core.workflow.human_input_adapter import DeliveryMethodType
     from models.human_input import ConsoleDeliveryPayload
 
     delivery = HumanInputDelivery(
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py
index 22e0aa34ff..fa78f1c28b 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py
@@ -7,6 +7,11 @@ from datetime import datetime
 from decimal import Decimal
 from uuid import uuid4
 
+from sqlalchemy import Engine
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories.factory import OrderConfig
 from graphon.entities import WorkflowNodeExecution
 from graphon.enums import (
     BuiltinNodeTypes,
@@ -14,11 +19,6 @@ from graphon.enums import (
     WorkflowNodeExecutionStatus,
 )
 from graphon.model_runtime.utils.encoders import jsonable_encoder
-from sqlalchemy import Engine
-from sqlalchemy.orm import Session, sessionmaker
-
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.factory import OrderConfig
 from models.account import Account, Tenant
 from models.enums import CreatorUserRole
 from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
index c5e9201ee3..d6f0657380 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py
@@ -7,12 +7,12 @@ from datetime import timedelta
 from uuid import uuid4
 
 import pytest
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import Engine, delete
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import Session, sessionmaker
 
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowExecutionStatus
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowRun, WorkflowType
diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py
index cc9596d15f..9a53ff087c 100644
--- a/api/tests/test_containers_integration_tests/services/test_account_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_account_service.py
@@ -9,7 +9,7 @@ from werkzeug.exceptions import Unauthorized
 
 from configs import dify_config
 from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace
-from models import AccountStatus, TenantAccountJoin
+from models import AccountStatus, TenantAccountJoin, TenantStatus
 from services.account_service import AccountService, RegisterService, TenantService, TokenPair
 from services.errors.account import (
     AccountAlreadyInTenantError,
@@ -2851,7 +2851,7 @@ class TestRegisterService:
             interface_language="en-US",
             password=existing_pending_member_password,
         )
-        existing_account.status = "pending"
+        existing_account.status = AccountStatus.PENDING
 
         db_session_with_containers.commit()
 
@@ -2941,7 +2941,7 @@ class TestRegisterService:
             interface_language="en-US",
             password=already_in_tenant_password,
         )
-        existing_account.status = "active"
+        existing_account.status = AccountStatus.ACTIVE
 
         db_session_with_containers.commit()
 
@@ -3331,7 +3331,7 @@ class TestRegisterService:
         TenantService.create_tenant_member(tenant, account, role="normal")
 
         # Change tenant status to non-normal
-        tenant.status = "archive"
+        tenant.status = TenantStatus.ARCHIVE
 
         db_session_with_containers.commit()
 
diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py
index 4f3c0e4200..00a2f9a59f 100644
a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,7 +842,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType - from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 6c15587058..77ce28b999 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -9,7 +9,6 @@ from uuid import uuid4 import pytest import yaml from faker import Faker -from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -17,6 +16,7 @@ from core.trigger.constants import ( TRIGGER_WEBHOOK_NODE_TYPE, ) from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import AppModelConfig, IconType from services import app_dsl_service diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..3229693fd4 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -36,12 +36,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": "test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +108,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +127,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -465,6 +475,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True 
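+        # BILLING_ENABLED is read via three separately patched import sites
+        # (app_generate_service, quota_service, configs); all three must agree
+        # for the two-phase quota path (quota_reserve -> quota_commit) to run.
+        # A rough sketch of the flow assumed by the assertions below:
+        #   reservation = billing.quota_reserve(...)   # hold quota up front
+        #   responses = generate(...)                  # do the actual work
+        #   billing.quota_commit(...)                  # settle the reservation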
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +489,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fa57dd4a6f..b695ae9fd9 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -658,15 +658,17 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" + new_icon_type = "image" mock_current_user = create_autospec(Account, instance=True) mock_current_user.id = account.id mock_current_user.current_tenant_id = account.current_tenant_id with patch("services.app_service.current_user", mock_current_user): - updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type) assert updated_app.icon == new_icon assert updated_app.icon_background == new_icon_background + assert str(updated_app.icon_type).lower() == new_icon_type assert updated_app.updated_by == account.id # Verify other fields remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py new file mode 100644 index 0000000000..0b7bd9ca64 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from graphon.variables import FloatVariable, IntegerVariable, StringVariable +from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource +from models.model import App, Conversation, EndUser +from models.workflow import ConversationVariable +from services.conversation_service import ConversationService +from services.errors.conversation import ( + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) + + +class ConversationServiceVariableIntegrationFactory: + @staticmethod + def create_app_and_account(db_session_with_containers): + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"conversation-variable-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + 
db_session_with_containers.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + return app, account + + @staticmethod + def create_end_user(db_session_with_containers, app: App): + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type=InvokeFrom.SERVICE_API.value, + external_user_id=f"external-{uuid4()}", + name=f"End User {uuid4()}", + is_anonymous=False, + session_id=f"session-{uuid4()}", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + return end_user + + @staticmethod + def create_conversation( + db_session_with_containers, + app: App, + user: Account | EndUser, + *, + name: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.WEB_APP, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> Conversation: + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=name or f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=invoke_from.value, + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + if created_at is not None: + conversation.created_at = created_at + if updated_at is not None: + conversation.updated_at = updated_at + + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + return conversation + + @staticmethod + def create_variable( + db_session_with_containers, + *, + app: App, + conversation: Conversation, + variable: StringVariable | FloatVariable | IntegerVariable, + created_at: datetime | None = None, + ) -> ConversationVariable: + row = ConversationVariable.from_variable(app_id=app.id, conversation_id=conversation.id, variable=variable) + if created_at is not None: + row.created_at = created_at + row.updated_at = created_at + + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + +@pytest.fixture +def real_conversation_service_session_factory(flask_app_with_containers): + del flask_app_with_containers + real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + + with ( + patch("services.conversation_service.session_factory.create_session", side_effect=lambda: real_session_maker()), + patch("services.conversation_service.session_factory.get_session_maker", return_value=real_session_maker), + ): + yield + + +class TestConversationServiceVariables: + def test_get_conversational_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = 
ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + older_time = datetime(2024, 1, 1, 12, 0, 0) + newer_time = older_time + timedelta(minutes=5) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=older_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=newer_time, + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=None, + ) + + assert [item["id"] for item in result.data] == [first_variable.id, second_variable.id] + assert [item["name"] for item in result.data] == ["topic", "priority"] + assert result.limit == 10 + assert result.has_more is False + + def test_get_conversational_variable_with_last_id( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + base_time = datetime(2024, 1, 1, 9, 0, 0) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=base_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=base_time + timedelta(minutes=1), + ) + third_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="owner", value="alice"), + created_at=base_time + timedelta(minutes=2), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=first_variable.id, + ) + + assert [item["id"] for item in result.data] == [second_variable.id, third_variable.id] + assert result.has_more is False + + def test_get_conversational_variable_last_id_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=str(uuid4()), + ) + + def test_get_conversational_variable_sets_has_more( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + 
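+        # Seed one conversation with three variables so the limit=2 query below
+        # returns a partial page and reports has_more=True.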
conversation = factory.create_conversation(db_session_with_containers, app, account) + + for index in range(3): + factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name=f"var_{index}", value=f"value_{index}"), + created_at=datetime(2024, 1, 1, 10, 0, index), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=2, + last_id=None, + ) + + assert len(result.data) == 2 + assert result.has_more is True + + def test_update_conversation_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + ) + updated_at = datetime(2024, 1, 1, 15, 0, 0) + + with patch("services.conversation_service.naive_utc_now", return_value=updated_at): + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="support", + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == "support" + assert result["id"] == existing.id + assert result["value"] == "support" + assert result["updated_at"] == updated_at + + def test_update_conversation_variable_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=str(uuid4()), + user=account, + new_value="support", + ) + + def test_update_conversation_variable_type_mismatch_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=FloatVariable(id=str(uuid4()), name="score", value=1.5), + ) + + with pytest.raises(ConversationVariableTypeMismatchError, match="expects float"): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="wrong-type", + ) + + def test_update_conversation_variable_integer_number_compatibility( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + 
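+        # The fixture is requested only for its session-factory patches; the
+        # immediate `del` marks it as intentionally unused.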
factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=IntegerVariable(id=str(uuid4()), name="attempts", value=1), + ) + + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value=42, + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == 42 + assert result["value"] == 42 + + +class TestConversationServicePaginationWithContainers: + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=str(uuid4()), + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + base_time = datetime(2024, 1, 1, 8, 0, 0) + newest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Newest", + updated_at=base_time + timedelta(minutes=2), + ) + middle = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Middle", + updated_at=base_time + timedelta(minutes=1), + ) + oldest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Oldest", + updated_at=base_time, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=middle.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert newest.id != middle.id + assert [conversation.id for conversation in result.data] == [oldest.id] + + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") + beta = factory.create_conversation(db_session_with_containers, app, account, name="Beta") + gamma = factory.create_conversation(db_session_with_containers, app, account, name="Gamma") + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=beta.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + sort_by="name", + ) + + assert alpha.id != beta.id + assert [conversation.id for conversation in result.data] == [gamma.id] + + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = 
factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + end_user_conversation = factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=end_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert account_conversation.id != end_user_conversation.id + assert [conversation.id for conversation in result.data] == [end_user_conversation.id] + + def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert [conversation.id for conversation in result.data] == [account_conversation.id] diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index fb0adbbcc2..02ab3f8314 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,10 +3,10 @@ from uuid import uuid4 import pytest -from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker from extensions.ext_database import db +from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f9bfa570cb..0de3c64c4f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git 
a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py new file mode 100644 index 0000000000..2bec703f0c --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -0,0 +1,650 @@ +"""Testcontainers integration tests for SQL-backed DocumentService paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.account import NoPermissionError + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +class DocumentServiceIntegrationFactory: + @staticmethod + def create_dataset( + db_session_with_containers, + *, + tenant_id: str | None = None, + created_by: str | None = None, + name: str | None = None, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name or f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + dataset: Dataset, + name: str = "doc.txt", + position: int = 1, + tenant_id: str | None = None, + indexing_status: str = IndexingStatus.COMPLETED, + enabled: bool = True, + archived: bool = False, + is_paused: bool = False, + need_summary: bool = False, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + batch: str | None = None, + data_source_type: str = DataSourceType.UPLOAD_FILE, + data_source_info: dict | None = None, + created_by: str | None = None, + ) -> Document: + document = Document( + tenant_id=tenant_id or dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=batch or f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or dataset.created_by, + doc_form=doc_form, + ) + document.indexing_status = indexing_status + document.enabled = enabled + document.archived = archived + document.is_paused = is_paused + document.need_summary = need_summary + if indexing_status == IndexingStatus.COMPLETED: + document.completed_at = FIXED_UPLOAD_CREATED_AT + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_upload_file( + db_session_with_containers, + *, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "source.txt", + ) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + if file_id: + upload_file.id = file_id + db_session_with_containers.add(upload_file) + 
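+        # Commit immediately so the helper hands back a persisted row whose id
+        # (possibly the caller-supplied file_id) is safe to reference later.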
db_session_with_containers.commit() + return upload_file + + +@pytest.fixture +def current_user_mock(): + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.id = str(uuid4()) + current_user.current_tenant_id = str(uuid4()) + current_user.current_role = None + yield current_user + + +def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_document(dataset.id, None) is None + + +def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) + + result = DocumentService.get_document(dataset.id, document.id) + + assert result is not None + assert result.id == document.id + + +def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + result = DocumentService.get_documents_by_ids(dataset.id, []) + + assert result == [] + + +def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt") + doc_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="b.txt", + position=2, + ) + + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + + assert {document.id for document in result} == {doc_a.id, doc_b.id} + + +def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + + +def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + paragraph_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + need_summary=True, + ) + qa_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + need_summary=True, + doc_form=IndexStructureType.QA_INDEX, + ) + + updated_count = DocumentService.update_documents_need_summary( + dataset.id, + [paragraph_doc.id, qa_doc.id], + need_summary=False, + ) + + db_session_with_containers.expire_all() + refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id) + refreshed_qa = db_session_with_containers.get(Document, qa_doc.id) + assert updated_count == 1 + assert refreshed_paragraph is not None + assert refreshed_qa is not None + assert refreshed_paragraph.need_summary is False + assert refreshed_qa.need_summary is True + + +def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + 
created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) + + +def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info={"url": "https://example.com"}, + ) + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={}, + ) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": 99}, + ) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + +def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": "missing-file"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + +def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result.id == upload_file.id + + +def 
test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[str(uuid4())], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + tenant_id=str(uuid4()), + data_source_info={"upload_file_id": upload_file.id}, + ) + + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": str(uuid4())}, + ) + + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + mapping = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document_a.id, document_b.id], + tenant_id=dataset.tenant_id, + ) + + assert mapping[document_a.id].id == upload_file_a.id + assert mapping[document_b.id].id == upload_file_b.id + + +def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( + current_user_mock, flask_app_with_containers +): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=str(uuid4()), + document_ids=[str(uuid4())], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + 
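+# The prepare_document_batch_download_zip tests around this point assume the
+# service follows roughly this control flow (a sketch inferred from the
+# assertions, not the actual implementation):
+#
+#     dataset = DatasetService.get_dataset(dataset_id)  # hypothetical lookup
+#     if dataset is None:
+#         raise NotFound("Dataset not found.")
+#     try:
+#         DatasetService.check_dataset_permission(dataset, current_user)
+#     except NoPermissionError as e:
+#         raise Forbidden(str(e))
+#     mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
+#         dataset_id=dataset_id, document_ids=document_ids, tenant_id=tenant_id)
+#     return [mapping[doc_id] for doc_id in document_ids], f"{dataset.name}.zip"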
+def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + + with patch( + "services.dataset_service.DatasetService.check_dataset_permission", + side_effect=NoPermissionError("denied"), + ): + with pytest.raises(Forbidden, match="denied"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[document_b.id, document_a.id], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] + assert download_name.endswith(".zip") + + +def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + enabled_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + enabled=False, + ) + + result = DocumentService.get_document_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [enabled_document.id] + + +def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + available_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + archived=False, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.ERROR, + ) + + result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [available_document.id] + + +def 
test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + error_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.ERROR, + ) + paused_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.PAUSED, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=3, + indexing_status=IndexingStatus.COMPLETED, + ) + + result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + + assert {document.id for document in result} == {error_document.id, paused_document.id} + + +def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + batch = f"batch-{uuid4()}" + matching_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + batch=batch, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + tenant_id=str(uuid4()), + batch=batch, + ) + + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = dataset.tenant_id + result = DocumentService.get_batch_documents(dataset.id, batch) + + assert [document.id for document in result] == [matching_document.id] + + +def test_get_document_file_detail_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is not None + assert result.id == upload_file.id + + +def test_delete_document_emits_signal_and_commits(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.document_was_deleted.send") as signal_send: + DocumentService.delete_document(document) + + assert db_session_with_containers.get(Document, document.id) is None + signal_send.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id=upload_file.id, + ) + + +def test_delete_documents_ignores_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, []) + + delay.assert_not_called() + + +def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + 
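+    # delete_documents() presumably reads the dataset's chunk structure when
+    # dispatching batch_clean_document_task, so it is set explicitly here; the
+    # assertions below only pin args[0], args[1] and args[3].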
dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX + db_session_with_containers.commit() + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + + assert db_session_with_containers.get(Document, document_a.id) is None + assert db_session_with_containers.get(Document, document_b.id) is None + delay.assert_called_once() + args = delay.call_args.args + assert args[0] == [document_a.id, document_b.id] + assert args[1] == dataset.id + assert set(args[3]) == {upload_file_a.id, upload_file_b.id} + + +def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3) + + assert DocumentService.get_documents_position(dataset.id) == 4 + + +def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_documents_position(dataset.id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py new file mode 100644 index 0000000000..1b4179c9c7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -0,0 +1,613 @@ +"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + 
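+        # Each call provisions a fresh account and tenant pair so permission
+        # tests stay isolated while sharing the containerised database.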
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant( + db_session_with_containers: Session, + tenant: Tenant, + role: TenantAccountRole = TenantAccountRole.EDITOR, + ) -> Account: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + *, + tenant_id: str, + created_by: str, + name: str | None = None, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, + enable_api: bool = True, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=name or f"dataset-{uuid4()}", + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + created_by=created_by, + provider="vendor", + permission=permission, + retrieval_model={"top_k": 2}, + ) + dataset.enable_api = enable_api + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_dataset_permission( + db_session_with_containers: Session, + *, + dataset_id: str, + tenant_id: str, + account_id: str, + ) -> DatasetPermission: + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + return permission + + @staticmethod + def create_app_dataset_join( + db_session_with_containers: Session, + *, + dataset_id: str, + ) -> AppDatasetJoin: + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return join + + @staticmethod + def create_collection_binding( + db_session_with_containers: Session, + *, + provider_name: str, + model_name: str, + collection_type: str = "dataset", + ) -> DatasetCollectionBinding: + binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=f"collection_{uuid4().hex}", + type=collection_type, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + return binding + + @staticmethod + def create_auto_disable_log( + db_session_with_containers: Session, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + ) -> DatasetAutoDisableLog: + log = DatasetAutoDisableLog( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + ) + db_session_with_containers.add(log) + db_session_with_containers.commit() + return log + + +class TestDatasetServicePermissionsAndLifecycle: + def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: 
Session): + owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + + result = DatasetService.delete_dataset(str(uuid4()), user=owner) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, user=owner) + + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + send_deleted_signal.assert_called_once_with(dataset) + + def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_app_dataset_join( + db_session_with_containers, + dataset_id=dataset.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is True + + def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is False + + def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, outsider) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = 
DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session): + creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=creator.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_allows_partial_team_member_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_allows_partial_team_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + 
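+            # Partial-team dataset owned by someone other than the operator;
+            # access comes only from the DatasetPermission binding created below.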
db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=operator.id, + ) + + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status(str(uuid4()), True) + + def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(id=None)): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + now = datetime(2026, 4, 14, 18, 0, 0) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.naive_utc_now", return_value=now), + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + db_session_with_containers.refresh(dataset) + assert dataset.enable_api is True + assert dataset.updated_by == owner.id + assert dataset.updated_at == now + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(str(uuid4())) + + assert result == {"document_ids": [], "count": 0} + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional")) + ) + + 
with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(dataset.id) + + assert result["count"] == 2 + assert len(result["document_ids"]) == 2 + + +class TestDatasetCollectionBindingServiceIntegration: + def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result.id == binding.id + + def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session): + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model") + + persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id) + assert persisted is not None + assert persisted.provider_name == "provider" + assert persisted.model_name == "missing-model" + assert persisted.type == "dataset" + assert persisted.collection_name + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + + assert result.id == binding.id + + +class TestDatasetPermissionServiceIntegration: + def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_a.id, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_b.id, + ) + + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + assert set(result) == {member_a.id, member_b.id} + + def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = 
DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=stale_member.id, + ) + + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + [{"user_id": member_a.id}, {"user_id": member_b.id}], + ) + + permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id} + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = 
DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetPermissionService.clear_partial_member_list(dataset.id) + + remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 2a2d86a8a6..ac0483a45d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -3,11 +3,18 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import ( + Account, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings from models.enums import DataSourceType from services.dataset_service import DatasetService @@ -26,12 +33,12 @@ class DatasetUpdateTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index c8f04e9215..fe426ae516 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. 
 from datetime import UTC, datetime, timedelta
 from uuid import uuid4
 
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import select
 
+from graphon.enums import WorkflowExecutionStatus
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowArchiveLog, WorkflowRun
 from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py
index b3e7dd2a59..315936d721 100644
--- a/api/tests/test_containers_integration_tests/services/test_feature_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py
@@ -274,6 +274,7 @@ class TestFeatureService:
         mock_config.ENABLE_EMAIL_CODE_LOGIN = True
         mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
         mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
+        mock_config.ENABLE_COLLABORATION_MODE = True
         mock_config.ALLOW_REGISTER = False
         mock_config.ALLOW_CREATE_WORKSPACE = False
         mock_config.MAIL_TYPE = "smtp"
@@ -298,6 +299,7 @@ class TestFeatureService:
         # Verify authentication settings
         assert result.enable_email_code_login is True
         assert result.enable_email_password_login is False
+        assert result.enable_collaboration_mode is True
         assert result.is_allow_register is False
         assert result.is_allow_create_workspace is False
 
@@ -401,6 +403,7 @@ class TestFeatureService:
         mock_config.ENABLE_EMAIL_CODE_LOGIN = True
         mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
         mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
+        mock_config.ENABLE_COLLABORATION_MODE = False
         mock_config.ALLOW_REGISTER = True
         mock_config.ALLOW_CREATE_WORKSPACE = True
         mock_config.MAIL_TYPE = "smtp"
@@ -422,6 +425,7 @@ class TestFeatureService:
         assert result.enable_email_code_login is True
         assert result.enable_email_password_login is True
         assert result.enable_social_oauth_login is False
+        assert result.enable_collaboration_mode is False
         assert result.is_allow_register is True
         assert result.is_allow_create_workspace is True
         assert result.is_email_setup is True
diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py
index d82933ccb9..3dcd6586e2 100644
--- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py
@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
 from services.feedback_service import FeedbackService
 
 
+def _execute_result(rows):
+    result = mock.Mock()
+    result.all.return_value = rows
+    return result
+
+
 class TestFeedbackService:
     """Test FeedbackService methods."""
 
@@ -81,25 +87,17 @@ class TestFeedbackService:
 
     def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
         """Test exporting feedback data in CSV format."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test CSV export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -120,25 +118,17 @@
 
     def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
         """Test exporting feedback data in JSON format."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test JSON export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -157,25 +147,17 @@
 
     def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
         """Test exporting feedback with various filters."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test with filters
         result = FeedbackService.export_feedbacks(
@@ -193,17 +175,7 @@
 
     def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
         """Test exporting feedback when no data exists."""
-
-        # Setup mock query result with no data
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = []
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result([])
 
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
 
@@ -251,24 +223,17 @@
             created_at=datetime(2024, 1, 1, 10, 0, 0),
         )
 
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                long_message,
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    long_message,
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -309,24 +274,17 @@
             created_at=datetime(2024, 1, 1, 10, 0, 0),
         )
 
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                chinese_feedback,
-                chinese_message,
-                sample_data["conversation"],
-                sample_data["app"],
-                None,  # No account for user feedback
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    chinese_feedback,
+                    chinese_message,
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    None,
+                )
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -339,32 +297,24 @@
 
     def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
         """Test that rating emojis are properly formatted in export."""
-
-        # Setup mock query result with both like and dislike feedback
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            ),
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            ),
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                ),
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                ),
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index c46b8fba0b..80f9083e81 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -3,15 +3,15 @@
 import uuid
 from unittest.mock import MagicMock
 
 import pytest
-from graphon.enums import BuiltinNodeTypes
-from graphon.nodes.human_input.entities import HumanInputNodeData
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     EmailDeliveryConfig,
     EmailDeliveryMethod,
     EmailRecipients,
     ExternalRecipient,
 )
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.human_input.entities import HumanInputNodeData
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.model import App, AppMode
 from models.workflow import Workflow, WorkflowType
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 0f252515f7..ed75363f3b 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -5,17 +5,17 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
-from graphon.runtime import VariablePool
 from sqlalchemy.engine import Engine
 
 from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     EmailDeliveryConfig,
     EmailDeliveryMethod,
     EmailRecipients,
     ExternalRecipient,
     MemberRecipient,
 )
+from graphon.runtime import VariablePool
 from models.account import Account, TenantAccountJoin
 from services import human_input_delivery_test_service as service_module
 from services.human_input_delivery_test_service import (
diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
index 2340dd2a03..cd63d3ad6c 100644
--- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py
@@ -8,11 +8,11 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
-from graphon.file import FileType
 from sqlalchemy.orm import Session
 
 from enums.cloud_plan import CloudPlan
 from extensions.ext_redis import redis_client
+from graphon.file import FileType
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.enums import (
     ConversationFromSource,
diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
index ba926bf675..8955a3b5f2 100644
--- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py
@@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
-from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType
 from sqlalchemy.orm import Session
 
 from core.entities.model_entities import ModelStatus
+from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType
 from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType
 from services.model_provider_service import ModelProviderService
@@ -405,11 +405,10 @@ class TestModelProviderService:
         mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value
 
         # Create mock models
+        from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity
        from graphon.model_runtime.entities.common_entities import I18nObject
         from graphon.model_runtime.entities.provider_entities import ProviderEntity
 
-        from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity
-
         # Create real model objects instead of mocks
         provider_entity_1 = SimpleModelProviderEntity(
             ProviderEntity(
@@ -644,9 +643,8 @@ class TestModelProviderService:
         mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value
 
         # Create mock default model response
-        from graphon.model_runtime.entities.common_entities import I18nObject
-
         from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
+        from graphon.model_runtime.entities.common_entities import I18nObject
 
         mock_default_model = DefaultModelEntity(
             model="gpt-3.5-turbo",
diff --git a/api/tests/test_containers_integration_tests/services/test_schedule_service.py b/api/tests/test_containers_integration_tests/services/test_schedule_service.py
new file mode 100644
index 0000000000..87f3306258
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_schedule_service.py
@@ -0,0 +1,387 @@
+"""Testcontainers integration tests for schedule service SQL-backed behavior."""
+
+from datetime import datetime
+from types import SimpleNamespace
+from uuid import uuid4
+
+import pytest
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate
+from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
+from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow
+from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models.trigger import WorkflowSchedulePlan
+from services.errors.account import AccountNotFoundError
+from services.trigger.schedule_service import ScheduleService
+
+
+class ScheduleServiceIntegrationFactory:
+    @staticmethod
+    def create_account_with_tenant(
+        db_session_with_containers: Session,
+        role: TenantAccountRole = TenantAccountRole.OWNER,
+    ) -> tuple[Account, Tenant]:
+        account = Account(
+            email=f"{uuid4()}@example.com",
+            name=f"user-{uuid4()}",
+            interface_language="en-US",
+            status="active",
+        )
+        tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
+        db_session_with_containers.add_all([account, tenant])
+        db_session_with_containers.flush()
+
+        join = TenantAccountJoin(
+            tenant_id=tenant.id,
+            account_id=account.id,
+            role=role,
+            current=True,
+        )
+        db_session_with_containers.add(join)
+        db_session_with_containers.commit()
+
+        account.current_tenant = tenant
+        return account, tenant
+
+    @staticmethod
+    def create_schedule_plan(
+        db_session_with_containers: Session,
+        *,
+        tenant_id: str,
+        app_id: str | None = None,
+        node_id: str = "start",
+        cron_expression: str = "30 10 * * *",
+        timezone: str = "UTC",
+        next_run_at: datetime | None = None,
+    ) -> WorkflowSchedulePlan:
+        schedule = WorkflowSchedulePlan(
+            tenant_id=tenant_id,
+            app_id=app_id or str(uuid4()),
+            node_id=node_id,
+            cron_expression=cron_expression,
+            timezone=timezone,
+            next_run_at=next_run_at,
+        )
+        db_session_with_containers.add(schedule)
+        db_session_with_containers.commit()
+        return schedule
+
+
+def _cron_workflow(
+    *,
+    node_id: str = "start",
+    cron_expression: str = "30 10 * * *",
+    timezone: str = "UTC",
+):
+    return SimpleNamespace(
+        graph_dict={
+            "nodes": [
+                {
+                    "id": node_id,
+                    "data": {
"type": "trigger-schedule", + "mode": "cron", + "cron_expression": cron_expression, + "timezone": timezone, + }, + } + ] + } + ) + + +def _no_schedule_workflow(): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": "node-1", + "data": {"type": "llm"}, + } + ] + } + ) + + +class TestScheduleServiceIntegration: + def test_create_schedule_persists_schedule(self, db_session_with_containers: Session): + account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + expected_next_run = datetime(2026, 1, 1, 10, 30, 0) + config = ScheduleConfig( + node_id="start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + schedule = ScheduleService.create_schedule( + session=db_session_with_containers, + tenant_id=tenant.id, + app_id=str(uuid4()), + config=config, + ) + + persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) + assert persisted is not None + assert persisted.tenant_id == tenant.id + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + cron_expression="30 10 * * *", + timezone="UTC", + ) + expected_next_run = datetime(2026, 1, 2, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate( + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + db_session_with_containers.refresh(updated) + assert updated.cron_expression == "0 12 * * *" + assert updated.timezone == "America/New_York" + assert updated.next_run_at == expected_next_run + + def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + initial_next_run = datetime(2026, 1, 1, 10, 0, 0) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + next_run_at=initial_next_run, + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + calls: list[tuple] = [] + + def _track(*args, **kwargs): + calls.append((args, kwargs)) + return datetime(2026, 1, 9, 10, 0, 0) + + monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + db_session_with_containers.refresh(updated) + assert updated.node_id == "node-new" + assert updated.next_run_at == initial_next_run + assert calls == [] + + def test_update_schedule_not_found_raises(self, db_session_with_containers: Session): + with 
pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + def test_delete_schedule_removes_row(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + db_session_with_containers.commit() + + assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None + + def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session): + owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.OWNER, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == owner.id + + def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session): + admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.ADMIN, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == admin.id + + def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + missing_account_id = str(uuid4()) + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=missing_account_id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=missing_account_id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=tenant.id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + expected_next_run = datetime(2026, 1, 3, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: 
expected_next_run, + ) + result = ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + + db_session_with_containers.refresh(schedule) + assert result == expected_next_run + assert schedule.next_run_at == expected_next_run + + def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + +class TestSyncScheduleFromWorkflowIntegration: + def test_sync_schedule_create_new(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + expected_next_run = datetime(2026, 1, 4, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow(), + ) + + assert result is not None + persisted = db_session_with_containers.execute( + select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id) + ).scalar_one() + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_update_existing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + node_id="old-start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + existing_id = existing.id + expected_next_run = datetime(2026, 1, 5, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow( + node_id="start", + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + assert result is not None + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id) + assert persisted is not None + assert persisted.node_id == "start" + assert persisted.cron_expression == "0 12 * * *" + assert persisted.timezone == "America/New_York" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + ) + existing_id = existing.id + + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_no_schedule_workflow(), + ) + + assert result is None + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None diff --git 
a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py new file mode 100644 index 0000000000..85ce3a6ba6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from enums.quota_type import QuotaType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import AppTriggerStatus, AppTriggerType +from models.model import App +from models.trigger import AppTrigger, WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger.webhook_service import WebhookService + + +class WebhookServiceRelationshipFactory: + @staticmethod + def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]: + account = Account( + name=f"Account {uuid4()}", + email=f"webhook-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App: + app = App( + tenant_id=tenant.id, + name=f"Webhook App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + @staticmethod + def create_workflow( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_ids: list[str], + version: str, + ) -> Workflow: + graph = { + "nodes": [ + { + "id": node_id, + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "title": f"Webhook {node_id}", + "method": "post", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + "status_code": 200, + "response_body": '{"status": "ok"}', + "timeout": 30, + }, + } + for node_id in node_ids + ], + "edges": [], + } + + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + version=version, + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + return workflow + + @staticmethod + def create_webhook_trigger( + db_session_with_containers: Session, + *, + app: App, + account: Account, + 
node_id: str, + webhook_id: str | None = None, + ) -> WorkflowWebhookTrigger: + webhook_trigger = WorkflowWebhookTrigger( + app_id=app.id, + node_id=node_id, + tenant_id=app.tenant_id, + webhook_id=webhook_id or uuid4().hex[:24], + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + return webhook_trigger + + @staticmethod + def create_app_trigger( + db_session_with_containers: Session, + *, + app: App, + node_id: str, + status: AppTriggerStatus, + ) -> AppTrigger: + app_trigger = AppTrigger( + tenant_id=app.tenant_id, + app_id=app.id, + node_id=node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + provider_name="webhook", + title=f"Webhook {node_id}", + status=status, + ) + db_session_with_containers.add(app_trigger) + db_session_with_containers.commit() + return app_trigger + + +class TestWebhookServiceLookupWithContainers: + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED + ) + + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED + ) + + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def 
test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED + ) + + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["published-node"], + version="2026-04-14.001", + ) + draft_workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["debug-node"], + version=Workflow.VERSION_DRAFT, + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="debug-node" + ) + + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + webhook_trigger.webhook_id, + is_debug=True, + ) + + assert got_trigger.id == webhook_trigger.id + assert got_workflow.id == draft_workflow.id + assert got_node_config["id"] == "debug-node" + + +class TestWebhookServiceTriggerExecutionWithContainers: + def test_trigger_workflow_execution_triggers_async_workflow_successfully( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + end_user = SimpleNamespace(id=str(uuid4())) + webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + + quota_charge = MagicMock() + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=end_user, + ), + patch( + "services.trigger.webhook_service.QuotaService.reserve", + return_value=quota_charge, + ) as mock_reserve, + patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, + ): + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + mock_reserve.assert_called_once() + reserve_args = mock_reserve.call_args.args + assert reserve_args[0] == QuotaType.TRIGGER + assert reserve_args[1] == webhook_trigger.tenant_id + quota_charge.commit.assert_called_once() + mock_trigger.assert_called_once() + trigger_args = mock_trigger.call_args.args + 
assert trigger_args[1] is end_user + assert trigger_args[2].workflow_id == workflow.id + assert trigger_args[2].root_node_id == webhook_trigger.node_id + + def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=SimpleNamespace(id=str(uuid4())), + ), + patch( + "services.trigger.webhook_service.QuotaService.reserve", + side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), + ), + patch( + "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited" + ) as mock_mark_rate_limited, + ): + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_mark_rate_limited.assert_called_once_with(tenant.id) + + def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), + ), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_logger_exception.assert_called_once() + + +class TestWebhookServiceRelationshipSyncWithContainers: + def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)] + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT + ) + + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + def 
test_sync_webhook_relationships_raises_when_lock_not_acquired( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = False + + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + stale_trigger = factory.create_webhook_trigger( + db_session_with_containers, + app=app, + account=account, + node_id="node-stale", + webhook_id="stale-webhook-id-000001", + ) + stale_trigger_id = stale_trigger.id + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-new"], + version=Workflow.VERSION_DRAFT, + ) + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + db_session_with_containers.expire_all() + records = db_session_with_containers.scalars( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id) + ).all() + + assert [record.node_id for record in records] == ["node-new"] + assert records[0].webhook_id == "new-webhook-id-000001" + assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None + + def test_sync_webhook_relationships_sets_redis_cache_for_new_record( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-cache"], + version=Workflow.VERSION_DRAFT, + ) + cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache" + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key) + assert cached_payload is not None + assert cached_payload["node_id"] == "node-cache" + assert cached_payload["webhook_id"] == "cache-webhook-id-00001" + + def test_sync_webhook_relationships_logs_when_lock_release_fails( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = 
factory.create_app(db_session_with_containers, tenant, account) +        workflow = factory.create_workflow( +            db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT +        ) +        lock = MagicMock() +        lock.acquire.return_value = True +        lock.release.side_effect = RuntimeError("release failed") + +        with ( +            patch("services.trigger.webhook_service.redis_client.lock", return_value=lock), +            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, +        ): +            WebhookService.sync_webhook_relationships(app, workflow) + +        mock_logger_exception.assert_called_once() + + +def _read_cache(cache_key: str) -> dict[str, str] | None: +    from extensions.ext_redis import redis_client + +    cached = redis_client.get(cache_key) +    if not cached: +        return None +    if isinstance(cached, bytes): +        cached = cached.decode("utf-8") +    return json.loads(cached) + + +# Attach post-definition as a staticmethod so tests can call WebhookServiceRelationshipFactory._read_cache(...). +WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 749c6fff5b..1e57b5603d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 0c281c8c33..86cf2327c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker -from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index d3e765055a..af83adaae0 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -1,3 +1,5 @@ +import inspect +import json from unittest.mock import patch import pytest @@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.errors import ApiToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager from models import Account, Tenant from models.tools import ApiToolProvider from
services.tools.api_tools_manage_service import ApiToolManageService @@ -590,30 +594,204 @@ class TestApiToolManageService: with pytest.raises(ValueError, match="you have not added provider"): ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") -    def test_update_api_tool_provider_not_found( +    def test_update_api_tool_provider_success( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): -        """Test update raises ValueError when original provider not found.""" fake = Faker() + +        # Stub the cache returned by the encrypter so cache.delete() in the update flow is a no-op +        mock_encrypter = mock_external_service_dependencies["encrypter"] +        from unittest.mock import MagicMock + +        mock_cache = MagicMock() +        mock_cache.delete.return_value = None +        mock_encrypter.return_value = (mock_encrypter, mock_cache) + +        # Arrange: Create test account and tenant account, tenant = self._create_test_account_and_tenant( db_session_with_containers, mock_external_service_dependencies ) -        with pytest.raises(ValueError, match="does not exists"): -            ApiToolManageService.update_api_tool_provider( +        # original provider name +        original_name = "original-provider" + +        # Create original provider +        _ = ApiToolManageService.create_api_tool_provider( +            user_id=account.id, +            tenant_id=tenant.id, +            provider_name=original_name, +            icon={"type": "emoji", "value": "🔧"}, +            credentials={"auth_type": "none"}, +            schema_type=ApiProviderSchemaType.OPENAPI, +            schema=self._create_test_openapi_schema(), +            privacy_policy="", +            custom_disclaimer="", +            labels=["old-label"], +        ) + +        # New provider name and new labels for the update +        new_name = "updated-provider" +        new_labels = ["new-label-1", "new-label-2"] + +        # Reset mock history so assertions focus on the update path only +        mock_external_service_dependencies["encrypter"].reset_mock() +        mock_external_service_dependencies["provider_controller"].from_db.reset_mock() +        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + +        # Act: Update the provider with new values +        result = ApiToolManageService.update_api_tool_provider( +            user_id=account.id, +            tenant_id=tenant.id, +            # new provider name - changed 1 +            provider_name=new_name, +            original_provider=original_name, +            # new icon - changed 2 +            icon={"type": "emoji", "value": "🚀"}, +            credentials={"auth_type": "none"}, +            _schema_type=ApiProviderSchemaType.OPENAPI, +            schema=self._create_test_openapi_schema(), +            # new privacy policy - changed 3 +            privacy_policy="https://new-policy.com", +            # new custom disclaimer - changed 4 +            custom_disclaimer="New disclaimer", +            # new labels - changed 5 (not verified directly here; persisting labels is not this layer's responsibility)
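+            # Note: the labels path is still exercised below, where we assert the Session handed to +            # ToolLabelManager.update_tool_labels.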
+            labels=new_labels, +        ) + +        # Assert: Verify the result +        assert result == {"result": "success"} + +        # Get the updated provider from the database +        updated_provider: ApiToolProvider | None = ( +            db_session_with_containers.query(ApiToolProvider) +            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name) +            .first() +        ) + +        # Verify the provider was updated successfully +        assert updated_provider is not None + +        # Refresh so assertions read the latest persisted state +        db_session_with_containers.refresh(updated_provider) +        # Verify all the updated fields +        # - changed 1 +        assert updated_provider.name == new_name +        # - changed 2 +        icon_data = json.loads(updated_provider.icon) +        assert icon_data["type"] == "emoji" +        assert icon_data["value"] == "🚀" +        # - changed 3 +        assert updated_provider.privacy_policy == "https://new-policy.com" +        # - changed 4 +        assert updated_provider.custom_disclaimer == "New disclaimer" + +        # Verify old provider name no longer exists after rename +        original_provider: ApiToolProvider | None = ( +            db_session_with_containers.query(ApiToolProvider) +            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name) +            .first() +        ) +        assert original_provider is None + +        # Verify update flow calls critical collaborators +        mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() +        mock_external_service_dependencies["encrypter"].assert_called_once() +        mock_cache.delete.assert_called_once() + +        # Verify session propagation in the label update logic: +        # the refactor passes the session down to the label manager to keep the update atomic, +        # so assert that a real Session instance actually reaches it. +        sig = inspect.signature(ToolLabelManager.update_tool_labels) +        args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args +        bound_args = sig.bind(*args, **kwargs) +        passed_session = bound_args.arguments.get("session") +        assert passed_session is not None, ( +            "Atomicity failure: no Session was passed to ToolLabelManager in update_api_tool_provider" +        ) +        # Ensure the type: Session +        assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}" + +    def test_update_api_tool_provider_not_found( +        self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies +    ): +        """ +        Test update raises ApiToolProviderNotFoundError when the original provider is not found.
+ + This test verifies: + - Proper error when trying to update a non-existing original provider + - No accidental upsert/new provider creation + - No external dependency invocation on early failure path + """ + # Arrange: Create test account and tenant + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Keep an existing provider in DB to ensure unrelated data remains unchanged + existing_provider_name = "existing-provider" + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=existing_provider_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="https://existing-policy.com", + custom_disclaimer="Existing disclaimer", + labels=["existing-label"], + ) + + # Reset mock history so assertions focus on update failure path only + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + + # Act & Assert: Verify update fails with clear error message + target_new_name = "new-provider-name" + missing_original_name = "missing-original-provider" + with pytest.raises(ApiToolProviderNotFoundError) as exc_info: + _ = ApiToolManageService.update_api_tool_provider( user_id=account.id, tenant_id=tenant.id, - provider_name="new-name", - original_provider="nonexistent", - icon={}, + provider_name=target_new_name, + original_provider=missing_original_name, + icon={"type": "emoji", "value": "🚀"}, credentials={"auth_type": "none"}, _schema_type=ApiProviderSchemaType.OPENAPI, schema=self._create_test_openapi_schema(), - privacy_policy=None, - custom_disclaimer="", - labels=[], + privacy_policy="https://new-policy.com", + custom_disclaimer="New disclaimer", + labels=["new-label"], ) + error = exc_info.value + assert error.provider_name == missing_original_name + assert error.tenant_id == tenant.id + assert error.error_code == "api_tool_provider_not_found" + + # Assert: Existing provider should remain unchanged + existing_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name) + .first() + ) + assert existing_provider is not None + assert existing_provider.name == existing_provider_name + + # Assert: No new provider should be created + unexpected_new_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name) + .first() + ) + assert unexpected_new_provider is None + + # Assert: Early failure should skip all downstream external interactions + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called() + mock_external_service_dependencies["encrypter"].assert_not_called() + mock_external_service_dependencies["provider_controller"].from_db.assert_not_called() + def test_update_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py 
b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce2fd2eeb1..ce5c2bd162 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,9 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -21,6 +18,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 7c43bf676b..4dab895135 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb..fcc15aad42 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - 
.where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe43137..e29ca7ebab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = 
db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == 
document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32f..05827112d4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment 
content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1dd37fbc92..32bc2fc0bd 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session 
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -52,18 +53,18 @@ class TestCleanDatasetTask: from extensions.ext_redis import redis_client # Clear all test data using the provided session fixture - db_session_with_containers.query(DatasetMetadataBinding).delete() - db_session_with_containers.query(DatasetMetadata).delete() - db_session_with_containers.query(AppDatasetJoin).delete() - db_session_with_containers.query(DatasetQuery).delete() - db_session_with_containers.query(DatasetProcessRule).delete() - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DatasetMetadataBinding)) + db_session_with_containers.execute(delete(DatasetMetadata)) + db_session_with_containers.execute(delete(AppDatasetJoin)) + db_session_with_containers.execute(delete(DatasetQuery)) + db_session_with_containers.execute(delete(DatasetProcessRule)) + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -302,28 +303,40 @@ class TestCleanDatasetTask: # Verify results # Check that dataset-related data was cleaned up - documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() assert len(documents) == 0 - segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(metadata) == 0 - bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.scalars( + select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id) + ).all() assert len(process_rules) == 0 - queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.scalars( + select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id) + ).all() assert len(queries) == 0 # Check that app dataset joins were 
cleaned up - app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.scalars( + select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id) + ).all() assert len(app_joins) == 0 # Verify index processor was called @@ -414,24 +427,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify index processor was called @@ -485,12 +506,14 @@ class TestCleanDatasetTask: # Check that all data was cleaned up - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 - remaining_segments = ( - db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - ) + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Recreate data for next test case @@ -538,11 +561,15 @@ class TestCleanDatasetTask: # Verify results - even with vector cleanup failure, documents and segments should be deleted # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert 
len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -622,18 +649,22 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = ( - db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() - ) + remaining_image_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(image_file_ids)) + ).all() assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -738,24 +769,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify performance expectations @@ -826,7 +865,9 @@ class TestCleanDatasetTask: # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file.id) + ).all() # The upload file should still be deleted from the database even if 
storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -976,19 +1017,27 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file_id) + ).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 926c839c8b..7e5c374b5d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,6 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import ColumnElement, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -20,6 +22,14 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password +def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 + + +def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 + + class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -145,24 +155,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - 
assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -322,12 +322,7 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Reset mock for next iteration mock_index_processor_factory.reset_mock() @@ -410,10 +405,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. @@ -499,11 +491,8 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 10 - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -513,22 +502,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(documents_to_clean)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 4 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. 
@@ -612,31 +591,36 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 4 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. - def test_clean_notion_document_task_database_transaction_rollback( + def test_clean_notion_document_task_continues_when_index_processor_fails( self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies ): """ - Test cleanup task behavior when database operations fail. + Index processor failure (e.g. transient billing API error propagated via + ``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding + model) must NOT abort the cleanup task. The Document rows have already + been hard-deleted in the first session block before vector cleanup runs, + so any uncaught exception escaping the task would strand + ``DocumentSegment`` rows in PG with no parent ``Document``. - This test verifies that the task properly handles database errors - and maintains data consistency. + Contract: the task swallows the index_processor exception, logs it, and + proceeds to delete the segments — leaving PG consistent. (Vector orphans, + if any, can be reaped later by an offline scanner.) + + Regression guard for the production incident where ``clean_document_task`` + / ``clean_notion_document_task`` failed with + ``ValueError("Unable to retrieve billing information...")`` and left + tens of thousands of orphan segments per affected tenant. """ fake = Faker() @@ -699,17 +683,28 @@ class TestCleanNotionDocumentTask: db_session_with_containers.add(segment) db_session_with_containers.commit() - # Mock index processor to raise an exception + # Simulate the production failure mode: index_processor.clean() raises a + # ValueError mirroring ``BillingService._send_request`` returning non-200. mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_index_processor.clean.side_effect = Exception("Index processor error") + mock_index_processor.clean.side_effect = ValueError( + "Unable to retrieve billing information. Please try again later or contact support." + ) - # Execute cleanup task - current implementation propagates the exception - with pytest.raises(Exception, match="Index processor error"): - clean_notion_document_task([document.id], dataset.id) + # Execute cleanup task — must NOT raise even though clean() raises. + # Before the safety-net wrapper this would have re-raised the ValueError, + # aborting the task and leaving DocumentSegment stranded in PG. + clean_notion_document_task([document.id], dataset.id) - # Note: This test demonstrates the task's error handling capability. - # Even with external service errors, the database operations complete successfully. 
- # In a production environment, proper error handling would determine transaction rollback behavior. + # Vector cleanup was attempted exactly once. + mock_index_processor.clean.assert_called_once() + + # The crucial assertion: despite the index processor failure, the + # final session block (line 51-52, ``DELETE FROM document_segments``) + # still ran and committed. This is what the wrapper buys us — without + # it the production incident left tens of thousands of orphan segments + # per affected tenant. Aligns with the assertion shape used by the + # happy-path test (``test_clean_notion_document_task_success``). + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies @@ -794,12 +789,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() - == num_documents - ) - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == num_documents * num_segments_per_doc ) @@ -808,10 +800,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. 
@@ -906,8 +895,8 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup # Note: There may be documents from previous tests, so we check for at least 3 - assert db_session_with_containers.query(Document).count() >= 3 - assert db_session_with_containers.query(DocumentSegment).count() >= 9 + assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3 + assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9 # Clean up documents from only the first dataset target_dataset = datasets[0] @@ -918,22 +907,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == target_document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. @@ -1028,11 +1007,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len( - document_statuses - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == len(document_statuses) * 2 ) @@ -1041,10 +1018,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. 
@@ -1142,20 +1116,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 3 - ) + assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a..684097851b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration: def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index d457b59d58..48fec441c5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -221,7 +222,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called @@ -322,7 +325,9 @@ class TestDealDatasetVectorIndexTask: 
deal_dataset_vector_index_task(dataset.id, "update") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called @@ -431,7 +436,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist @@ -564,7 +571,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to error - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error @@ -635,7 +644,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type @@ -711,7 +722,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type @@ -815,7 +828,9 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times @@ -917,7 +932,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify final document status - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( @@ -1027,12 +1044,14 @@ class 
TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only enabled document was processed - updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + updated_enabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == enabled_document.id).limit(1) + ) assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged - updated_disabled_document = ( - db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + updated_disabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == disabled_document.id).limit(1) ) assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1148,12 +1167,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only active document was processed - updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + updated_active_document = db_session_with_containers.scalar( + select(Document).where(Document.id == active_document.id).limit(1) + ) assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged - updated_archived_document = ( - db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + updated_archived_document = db_session_with_containers.scalar( + select(Document).where(Document.id == archived_document.id).limit(1) ) assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1269,14 +1290,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only completed document was processed - updated_completed_document = ( - db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + updated_completed_document = db_session_with_containers.scalar( + select(Document).where(Document.id == completed_document.id).limit(1) ) assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged - updated_incomplete_document = ( - db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + updated_incomplete_document = db_session_with_containers.scalar( + select(Document).where(Document.id == incomplete_document.id).limit(1) ) assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 3e9a0c8f7f..6e03bd9351 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask: db_session_with_containers.refresh(segments[1]) # Check that segments are 
re-enabled after error - updated_segments = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() - ) + updated_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + ).all() for segment in updated_segments: assert segment.enabled is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index d4021143ef..b6e7e6e5c9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -12,10 +12,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy import delete, func, select, update from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -30,12 +31,12 @@ class DocumentIndexingSyncTaskTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.flush() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.flush() @@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask: """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) - db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( - {"data_source_info": None} + db_session_with_containers.execute( + update(Document).where(Document.id == context["document"].id).values(data_source_info=None) ) db_session_with_containers.commit() @@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR @@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == 
context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask: context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index d94abf2b40..a9a8c0f30c 
100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask: db_session_with_containers.expire_all() # Assert document status updated before reindex - updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1)) assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining == 0 @@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask: mock_external_dependencies["runner_instance"].run.assert_called_once() # Segments should remain (since clean failed before DB delete) - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining > 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 6a8e186958..39c58987fd 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -317,7 +318,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -362,14 +363,14 @@ class TestDuplicateDocumentIndexingTasks: # Verify segments were deleted from database # Re-query segments from database using captured IDs to avoid stale ORM instances for seg_id in segment_ids: - deleted_segment = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == seg_id).limit(1) ) assert deleted_segment is None 
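The recurring substitution in this file replaces the legacy 1.x `Query.first()` chain with a 2.0-style statement executed through `Session.scalar()`. A minimal sketch of the equivalence, with `session`, `doc_id`, and the `fetch_document` wrapper as illustrative stand-ins for the fixtures used in these tests:

```python
# Illustrative only: both forms fetch a single Document row or None.
from sqlalchemy import select

from models.dataset import Document


def fetch_document(session, doc_id: str) -> Document | None:
    # Legacy 1.x Query API (removed in this diff):
    #   return session.query(Document).where(Document.id == doc_id).first()
    # 2.0-style replacement used throughout these hunks:
    return session.scalar(select(Document).where(Document.id == doc_id).limit(1))
```

`Session.scalar()` returns the first column of the first result row, or `None` when nothing matches, so the added `limit(1)` makes it a drop-in for `Query.first()` without fetching extra rows.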
         # Verify documents were updated to parsing status
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING
             assert updated_document.processing_started_at is not None
@@ -438,7 +439,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify only existing documents were updated
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in existing_document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING
             assert updated_document.processing_started_at is not None

@@ -485,7 +486,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _duplicate_document_indexing_task closes the session
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING
             assert updated_document.processing_started_at is not None

@@ -543,7 +544,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Assert: Verify error handling
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.ERROR
             assert updated_document.error is not None
             assert "batch upload" in updated_document.error.lower()
@@ -585,7 +586,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Assert: Verify error handling
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.ERROR
             assert updated_document.error is not None
             assert "limit" in updated_document.error.lower()
@@ -649,7 +650,7 @@ class TestDuplicateDocumentIndexingTasks:

         # Verify documents were processed
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING

     @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
@@ -692,7 +693,7 @@ class TestDuplicateDocumentIndexingTasks:

         # Verify documents were processed
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document =
db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -736,7 +737,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.is_paused is True assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 1b4dcf28ea..95a867dbb5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,9 +3,6 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete from configs import dify_config @@ -13,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -21,6 +18,9 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b5bef145d5..b43b622870 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,12 +2,12 @@ import uuid from unittest.mock import ANY, call, patch import pytest -from 
graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 4bc022c415..b00d827e37 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,16 +24,16 @@ from dataclasses import dataclass from datetime import timedelta import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.model import UploadFile from models.workflow import Workflow, WorkflowRun from repositories.sqlalchemy_api_workflow_run_repository import ( @@ -181,7 +181,7 @@ class TestWorkflowPauseIntegration: tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -190,7 +190,7 @@ class TestWorkflowPauseIntegration: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -696,7 +696,7 @@ class TestWorkflowPauseIntegration: tenant2 = Tenant( name="Test Tenant 2", - status="normal", + status=TenantStatus.NORMAL, ) self.session.add(tenant2) self.session.commit() @@ -705,7 +705,7 @@ class TestWorkflowPauseIntegration: email="test2@example.com", name="Test User 2", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) self.session.add(account2) self.session.commit() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index d725fb990a..9c20118e27 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,7 +10,6 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient -from graphon.enums import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,6 +24,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, 
build_plugin_pool_key +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus @@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/commands/test_generate_swagger_specs.py b/api/tests/unit_tests/commands/test_generate_swagger_specs.py new file mode 100644 index 0000000000..e77e875081 --- /dev/null +++ b/api/tests/unit_tests/commands/test_generate_swagger_specs.py @@ -0,0 +1,37 @@ +"""Unit tests for the standalone Swagger export helper.""" + +import importlib.util +import json +import sys +from pathlib import Path + + +def _load_generate_swagger_specs_module(): + api_dir = Path(__file__).resolve().parents[3] + script_path = api_dir / "dev" / "generate_swagger_specs.py" + + spec = importlib.util.spec_from_file_location("generate_swagger_specs", script_path) + assert spec + assert spec.loader + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + return module + + +def test_generate_specs_writes_console_web_and_service_swagger_files(tmp_path): + module = _load_generate_swagger_specs_module() + + written_paths = module.generate_specs(tmp_path) + + assert [path.name for path in written_paths] == [ + "console-swagger.json", + "web-swagger.json", + "service-swagger.json", + ] + + for path in written_paths: + payload = json.loads(path.read_text(encoding="utf-8")) + assert payload["swagger"] == "2.0" + assert "paths" in payload diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index d6933e2180..bad246a4bb 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): - """Test that DB_EXTRAS options are properly merged with default timezone setting""" + """Test that DB_EXTRAS options are merged with the default timezone startup option.""" # Set environment variables monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") @@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): # Create config config = DifyConfig() - # Get engine options - engine_options = config.SQLALCHEMY_ENGINE_OPTIONS - - # Verify options contains both search_path and timezone - options = engine_options["connect_args"]["options"] + options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options +def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + 
monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema") + monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "") + + config = DifyConfig() + + assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == { + "options": "-c search_path=myschema", + } + + def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): os.environ.clear() @@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest. _ = DifyConfig().normalized_pubsub_redis_url +def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "" + + +def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "enterprise-a" + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 55873b06a8..7174530e97 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine): configure_session_factory(_unit_test_engine, expire_on_commit=False) -def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): +def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner): """ - Helper to set up the mock DB execute chain for tenant/account authentication. + Helper to stub the tenant-owner execute result for service API app authentication. - This configures the mock to return (tenant, account) for the - db.session.execute(select(...).join().join().where()).one_or_none() - query used by validate_app_token decorator. + The validate_app_token decorator currently resolves the active tenant owner + via db.session.execute(select(Tenant, Account)...).one_or_none(). 
Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_account: Mock account object to return + mock_owner: Mock owner object to return from the execute result """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner) -def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): +def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join): """ - Helper to set up the mock DB execute chain for dataset tenant authentication. + Helper to stub the tenant-owner execute result for dataset token authentication. - This configures the mock to return (tenant, tenant_account) for the - db.session.execute(select(...).where().where().where().where()).one_or_none() - query used by validate_dataset_token decorator. + The validate_dataset_token decorator currently resolves the owner mapping via + db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and + then loads the Account separately via db.session.get(...). Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return - mock_ta: Mock tenant account object to return + mock_tenant_account_join: Mock tenant-account join object to return """ - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join) diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 9f1ff9b40f..bfa4048191 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") @@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation: csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with ( patch("services.annotation_service.current_account_with_tenant") as mock_auth, patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), @@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation: file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") - mock_db_session.query.return_value.where.return_value.first.return_value = mock_app - with patch("services.annotation_service.current_account_with_tenant") as mock_auth: 
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 0000000000..9c4678aed3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,139 @@ +"""Unit tests for console app import endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + return fake_session + + +class TestAppImportApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportApi() + + def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + def test_import_post_updates_webapp_auth_when_enabled(self, 
api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=True) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +class TestAppImportConfirmApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportConfirmApi() + + def test_import_confirm_returns_failed_status_and_rolls_back( + self, api, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + method = _unwrap(api.post) + + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037d..35d07a987d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -138,12 +138,15 @@ def app_models(app_module): def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses a deterministic helper for tests.""" - def _fake_signed_url(key: str | None) -> str | None: - if not key: + def _fake_build_icon_url(_icon_type, key: str | None) -> str | None: + if key is None: + return None + icon_type = str(_icon_type).lower() + if icon_type != "image": return None return f"signed:{key}" - monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url) def _ts(hour: int = 12) -> datetime: diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index c52bc02420..2d218dac7e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,7 +4,6 @@ import io from types import SimpleNamespace import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -21,6 +20,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from 
graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 11b3b3470d..24b7e39f73 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -33,12 +33,17 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1")) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -71,12 +76,17 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py deleted file mode 100644 index f588ab261d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from controllers.console.app.conversation import _get_conversation - - -def test_get_conversation_mark_read_keeps_updated_at_unchanged(): - app_model = SimpleNamespace(id="app-id") - account = SimpleNamespace(id="account-id") - conversation = MagicMock() - conversation.id = "conversation-id" - - with ( - patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, None), - autospec=True, - ), - patch( - "controllers.console.app.conversation.naive_utc_now", - return_value=datetime(2026, 2, 9, 0, 0, 0), - autospec=True, - ), - patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, - ): - mock_session.scalar.return_value = conversation - - _get_conversation(app_model, "conversation-id") - - statement = 
mock_session.execute.call_args[0][0] - compiled = statement.compile() - sql_text = str(compiled).lower() - compact_sql_text = sql_text.replace(" ", "") - params = compiled.params - - assert "updated_at=current_timestamp" not in compact_sql_text - assert "updated_at=conversations.updated_at" in compact_sql_text - assert "read_at=:read_at" in compact_sql_text - assert "read_account_id=:read_account_id" in compact_sql_text - assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0) - assert params["read_account_id"] == "account-id" diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py new file mode 100644 index 0000000000..1a412aff29 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from contextlib import nullcontext +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from pydantic import ValidationError + +from controllers.console.app import conversation_variables as conversation_variables_module +from graphon.variables.types import SegmentType + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + created_at = datetime(2026, 1, 1, tzinfo=UTC) + updated_at = datetime(2026, 1, 2, tzinfo=UTC) + row = SimpleNamespace( + created_at=created_at, + updated_at=updated_at, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-1", + "name": "my_var", + "value_type": "string", + "value": "value", + "description": "desc", + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["page"] == 1 + assert response["limit"] == 100 + assert response["total"] == 1 + assert response["has_more"] is False + assert response["data"][0]["id"] == "var-1" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["updated_at"] == int(updated_at.timestamp()) + + +def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + row = SimpleNamespace( + created_at=None, + updated_at=None, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-2", + "name": "my_var_2", + "value_type": SegmentType.INTEGER, + "value": 42, + "description": None, + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + 
monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["data"][0]["value_type"] == "number" + assert response["data"][0]["value"] == "42" + + +def test_get_conversation_variables_requires_conversation_id(app) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): + with pytest.raises(ValidationError): + method(app_model=SimpleNamespace(id="app-1")) diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py index baac4cd4e0..1af15d8dc6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py +++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py @@ -1,6 +1,25 @@ import datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch -from controllers.console.app.mcp_server import AppMCPServerResponse +from flask import Flask + +from controllers.console import console_ns +from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _ValidatedResponse: + def __init__(self, payload): + self._payload = payload + + def model_dump(self, mode="json"): + return self._payload class TestAppMCPServerResponse: @@ -40,6 +59,18 @@ class TestAppMCPServerResponse: resp = AppMCPServerResponse.model_validate(data) assert resp.parameters == {"already": "parsed"} + def test_parameters_json_array_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '["a", "b"]', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == ["a", "b"] + def test_timestamps_normalized(self): dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) data = { @@ -68,3 +99,40 @@ class TestAppMCPServerResponse: resp = AppMCPServerResponse.model_validate(data) assert resp.created_at is None assert resp.updated_at is None + + +class TestAppMCPServerController: + def test_get_returns_empty_dict_when_server_missing(self): + api = AppMCPServerController() + method = unwrap(api.get) + + with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None): + response = method(api, app_model=SimpleNamespace(id="app-1")) + + assert response == {} + + def test_post_returns_201(self): + api = AppMCPServerController() + method = unwrap(api.post) + payload = {"parameters": {"timeout": 30}} + app = Flask(__name__) + app.config["TESTING"] = True + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")), + patch("controllers.console.app.mcp_server.db.session.add"), + patch("controllers.console.app.mcp_server.db.session.commit"), + patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", 
return_value="server-code"), + patch( + "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate", + return_value=_ValidatedResponse({"id": "server-1"}), + ), + ): + response, status_code = method( + api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") + ) + + assert response == {"id": "server-1"} + assert status_code == 201 diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index a76e958829..c984dbef5d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import UTC, datetime + import pytest from controllers.console.app import message as message_module @@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) assert len(response.data) == 2 assert response.data[0] == "What is AI?" + + +def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageDetailResponse normalizes alias fields and datetime timestamps.""" + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = message_module.MessageDetailResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "conversation_id": "550e8400-e29b-41d4-a716-446655440001", + "inputs": {"foo": "bar"}, + "query": "hello", + "re_sign_file_url_answer": "world", + "from_source": "user", + "status": "normal", + "created_at": created_at, + "message_metadata_dict": {"token_usage": 3}, + } + ) + assert response.answer == "world" + assert response.metadata == {"token_usage": 3} + assert response.created_at == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 3607636880..e91c0a0597 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -1,15 +1,16 @@ from __future__ import annotations +import json from datetime import datetime from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from graphon.file import File, FileTransferMethod, FileType def _unwrap(func): @@ -30,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) @@ -258,6 +259,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( assert exc.value.description == "invalid workflow graph" +def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.PublishedAllWorkflowApi() + handler = _unwrap(api.get) + + session_state = {"open": False} + + class _SessionContext: + def __enter__(self): + session_state["open"] = True + return object() + + def __exit__(self, exc_type, exc, tb): 
+ session_state["open"] = False + return False + + class _SessionMaker: + def begin(self): + return _SessionContext() + + class _Workflow: + @property + def id(self): + assert session_state["open"] is True + return "w1" + + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False), + ), + ) + + def _fake_marshal(items, fields): + assert session_state["open"] is True + return [{"id": item.id} for item in items] + + monkeypatch.setattr(workflow_module, "marshal", _fake_marshal) + + with app.test_request_context( + "/apps/app/workflows", + method="GET", + query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1")) + + assert response == { + "items": [{"id": "w1"}], + "page": 1, + "limit": 10, + "has_more": False, + } + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) @@ -290,3 +348,87 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk ): with pytest.raises(NotFound): handler(api, app_model=SimpleNamespace(id="app")) + + +def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None: + app_id_1 = "11111111-1111-1111-1111-111111111111" + app_id_2 = "22222222-2222-2222-2222-222222222222" + signed_avatar_url = "https://files.example.com/signed/avatar-1" + sign_avatar = Mock(return_value=signed_avatar_url) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: {app_id_1}), + ) + monkeypatch.setattr(workflow_module.file_helpers, "get_signed_file_url", sign_avatar) + + workflow_module.redis_client.hgetall.side_effect = lambda key: ( + { + b"sid-1": json.dumps( + { + "user_id": "u-1", + "username": "Alice", + "avatar": "avatar-file-id", + "sid": "sid-1", + } + ) + } + if key == f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + else {} + ) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={app_id_1},{app_id_2}", + method="GET", + ): + response = handler(api) + + assert response == { + "data": [ + { + "app_id": app_id_1, + "users": [ + { + "user_id": "u-1", + "username": "Alice", + "avatar": signed_avatar_url, + "sid": "sid-1", + } + ], + } + ] + } + workflow_module.redis_client.hgetall.assert_called_once_with( + f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + ) + sign_avatar.assert_called_once_with("avatar-file-id") + + +def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + accessible_app_ids = Mock(return_value=set()) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: 
SimpleNamespace(get_accessible_app_ids=accessible_app_ids), + ) + + excessive_ids = ",".join(f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS + 1)) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={excessive_ids}", + method="GET", + ): + with pytest.raises(HTTPException) as exc: + handler(api) + + assert exc.value.code == 400 + assert "Maximum" in exc.value.description + accessible_app_ids.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py new file mode 100644 index 0000000000..a9853f25b0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from controllers.console.app import workflow_app_log as workflow_app_log_module +from graphon.enums import WorkflowExecutionStatus + + +def test_workflow_app_log_query_parses_bool_and_datetime(): + query = workflow_app_log_module.WorkflowAppLogQuery.model_validate( + { + "detail": "true", + "created_at__before": "2026-01-02T03:04:05Z", + "page": "2", + "limit": "10", + } + ) + + assert query.detail is True + assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + assert query.page == 2 + assert query.limit == 10 + + +def test_workflow_app_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "log-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.SUCCEEDED, + "created_at": created_at, + "finished_at": created_at, + }, + "details": {"trigger_metadata": {}}, + "created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"}, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "succeeded" + assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + + +def test_workflow_archived_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "archived-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.FAILED, + }, + "trigger_metadata": {"type": "trigger-plugin"}, + "created_by_end_user": { + "id": "eu-1", + "type": "anonymous", + "is_anonymous": True, + "session_id": "session-1", + }, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "failed" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py new file mode 100644 index 0000000000..85afcf0e60 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -0,0 +1,201 @@ +from __future__ import 
annotations + +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_comment as workflow_comment_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = role + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app() -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", status="normal", mode="workflow") + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(workflow_comment_module, "current_user", account) + + +def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None: + for method_name in ( + "create_comment", + "update_comment", + "delete_comment", + "resolve_comment", + "validate_comment_access", + "create_reply", + "update_reply", + "delete_reply", + ): + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, method_name, MagicMock()) + + +def _patch_payload(payload: dict[str, object] | None): + if payload is None: + return nullcontext() + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +@dataclass(frozen=True) +class WriteCase: + resource_cls: type + method_name: str + path: str + kwargs: dict[str, str] + payload: dict[str, object] | None = None + + +@pytest.mark.parametrize( + "case", + [ + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentListApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments", + kwargs={"app_id": "app-123"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + 
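# No payload for this DELETE case: it defaults to None, so _patch_payload yields a nullcontext.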
resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentResolveApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/resolve", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + ), + ], +) +def test_write_endpoints_require_edit_permission(app: Flask, monkeypatch: pytest.MonkeyPatch, case: WriteCase) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + _patch_write_services(monkeypatch) + + with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload): + with _patch_payload(case.payload): + handler = getattr(case.resource_cls(), case.method_name) + with pytest.raises(Forbidden): + handler(**case.kwargs) + + +def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + create_comment_mock = MagicMock(return_value={"id": "comment-1"}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock) + payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload): + with _patch_payload(payload): + result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") + + if isinstance(result, tuple): + response = result[0] + else: + response = result + assert response["id"] == "comment-1" + create_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + created_by="account-123", + content="hello", + position_x=1.0, + position_y=2.0, + mentioned_user_ids=[], + ) + + +def test_update_comment_omits_mentions_when_payload_does_not_include_them( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)}) + 
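# Stub the service write so the assertion below can inspect the exact kwargs the controller forwards.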
monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock) + payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload): + with _patch_payload(payload): + workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + + update_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + comment_id="comment-1", + user_id="account-123", + content="hello", + position_x=10.0, + position_y=20.0, + mentioned_user_ids=None, + ) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index e11102acb1..c4a8148446 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py new file mode 100644 index 0000000000..5363aa154f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from controllers.console.app import workflow_trigger as workflow_trigger_module + + +def test_parser_models_validate(): + parser = workflow_trigger_module.Parser(node_id="node-1") + enable_parser = workflow_trigger_module.ParserEnable( + trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True + ) + + assert parser.node_id == "node-1" + assert enable_parser.enable_trigger is True + + +def test_workflow_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="https://example.com/icon", + status="enabled", + created_at=created_at, + updated_at=created_at, + ) + + payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump( + mode="json" + ) + assert payload["id"] == "trigger-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" + assert payload["updated_at"] == "2026-01-02T03:04:05Z" + + +def test_webhook_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, 
tzinfo=UTC) + webhook = { + "id": "webhook-1", + "webhook_id": "whk-1", + "webhook_url": "https://example.com/hook", + "webhook_debug_url": "https://example.com/hook/debug", + "node_id": "node-1", + "created_at": created_at, + } + + payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json") + assert payload["webhook_id"] == "whk-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 740da1f1df..22b80b748e 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal -from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -16,6 +15,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index cb4fe40944..17bee94c52 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -43,7 +43,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = True # Act @@ -76,7 +75,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Act with self.app.test_request_context( @@ -109,7 +107,6 @@ class TestAuthenticationSecurity: mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = 
services.errors.account.AccountPasswordError("Invalid email or password.") - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_features.return_value.is_allow_register = False # Act @@ -135,7 +132,6 @@ class TestAuthenticationSecurity: def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db): """Test that reset password returns success with token for existing accounts.""" # Mock the setup check - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists # Test with existing account mock_get_user.return_value = MagicMock(email="existing@example.com") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 9929a71120..b7bc73da5f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi: - IP rate limiting is checked """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "email_token_123" @@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is allowed by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = True @@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is blocked by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = False @@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi: - Prevents spam and abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.side_effect = AccountRegisterError("Account frozen") @@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi: - Defaults to en-US when not specified """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "token" @@ -286,7 +280,6 @@ class TestEmailCodeLoginApi: - User is logged in with token pair """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -335,7 +328,6 @@ class TestEmailCodeLoginApi: - User is logged in after account creation """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"} mock_get_user.return_value = None mock_create_account.return_value = mock_account @@ -369,7 +361,6 @@ class TestEmailCodeLoginApi: - InvalidTokenError is raised for invalid/expired 
tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -392,7 +383,6 @@ class TestEmailCodeLoginApi: - InvalidEmailError is raised when email doesn't match token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} # Act & Assert @@ -415,7 +405,6 @@ class TestEmailCodeLoginApi: - EmailCodeError is raised for wrong verification code """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} # Act & Assert @@ -453,7 +442,6 @@ class TestEmailCodeLoginApi: - User is added as owner of new workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -496,7 +484,6 @@ class TestEmailCodeLoginApi: - WorkspacesLimitExceeded is raised when limit reached """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -538,7 +525,6 @@ class TestEmailCodeLoginApi: - NotAllowedCreateWorkspace is raised when creation disabled """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 0cf97da878..d089be8905 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -110,7 +110,6 @@ class TestLoginApi: - Rate limit is reset after successful login """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -162,7 +161,6 @@ class TestLoginApi: - Authentication proceeds with invitation token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "test@example.com"}} mock_authenticate.return_value = mock_account @@ -199,7 +197,6 @@ class TestLoginApi: - EmailPasswordLoginLimitError is raised when limit exceeded """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True mock_get_invitation.return_value = None @@ -228,7 +225,6 @@ class TestLoginApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_frozen.return_value = True # Act & Assert @@ -268,7 +264,6 @@ class TestLoginApi: - Generic error message prevents user enumeration """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountPasswordError("Invalid password") @@ -305,7 +300,6 @@ class TestLoginApi: - Login 
is prevented even with valid credentials """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountLoginError("Account is banned") @@ -351,7 +345,6 @@ class TestLoginApi: - User cannot login without an assigned workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -383,7 +376,6 @@ class TestLoginApi: - Security check prevents invitation token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}} @@ -425,7 +417,6 @@ class TestLoginApi: mock_token_pair, ): """Test that login retries with lowercase email when uppercase lookup fails.""" - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] @@ -459,7 +450,6 @@ class TestLoginApi: mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") @@ -513,7 +503,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_current_account.return_value = (mock_account, MagicMock()) # Act @@ -539,7 +528,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Create a mock anonymous user that will pass isinstance check anonymous_user = MagicMock() mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {}) diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index c80758c857..810f1b94fc 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -46,7 +46,6 @@ class TestPartnerTenants: patch("libs.login.dify_config.LOGIN_DISABLED", False), patch("libs.login.check_csrf_token") as mock_csrf, ): - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9c9f8da87c..5136922e88 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -18,6 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler +from 
graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 6ef8ccfdbd..63950736c5 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response -from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -16,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor ) from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 94d6c17915..9465936f28 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1772,6 +1772,21 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" + def test_get_api_base_url_no_double_v1(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com/v1", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + class TestDatasetRetrievalSettingApi: def test_get_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f..d9b02ac453 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -215,17 +216,23 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1") diff --git 
a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf3..09ed2aaf69 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 710c9be684..e4acd91b76 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -21,6 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 66c9ba48c5..b4b57022e2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,7 +2,6 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -20,6 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 2e4ca4f2a4..145cc9cdd7 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ 
b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -22,6 +21,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 02c7507ea7..76c863577a 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import controllers.console.explore.recommended_app as module +from models.model import AppMode, IconType def unwrap(func): @@ -90,3 +91,48 @@ class TestRecommendedAppApi: service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") assert result == result_data + + +class TestRecommendedAppResponseModels: + def test_recommended_app_info_response_computes_icon_url(self): + with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"): + payload = module.RecommendedAppInfoResponse.model_validate( + { + "id": "app-1", + "name": "App", + "mode": AppMode.CHAT, + "icon": "icon.png", + "icon_type": IconType.IMAGE, + "icon_background": "#fff", + } + ).model_dump(mode="json") + + assert payload["icon_url"] == "https://signed/icon.png" + + def test_recommended_app_list_response_serialization(self): + response = module.RecommendedAppListResponse.model_validate( + { + "recommended_apps": [ + { + "app": { + "id": "app-1", + "name": "App", + "mode": "chat", + "icon": "icon.png", + "icon_type": "emoji", + "icon_background": "#fff", + }, + "app_id": "app-1", + "description": "desc", + "category": "cat", + "position": 1, + "is_listed": True, + "can_trial": False, + } + ], + "categories": ["cat"], + } + ).model_dump(mode="json") + + assert response["recommended_apps"][0]["app_id"] == "app-1" + assert response["categories"] == ["cat"] diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 04beb31389..3625056af9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -26,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode @@ -94,7 +94,7 @@ class TestTrialAppWorkflowRunApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(MagicMock(mode=AppMode.CHAT)) + method(api, MagicMock(mode=AppMode.CHAT)) def test_success(self, app, trial_app_workflow, account): api = 
module.TrialAppWorkflowRunApi() @@ -106,7 +106,7 @@ class TestTrialAppWorkflowRunApi: patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(trial_app_workflow) + result = method(api, trial_app_workflow) assert result is not None @@ -124,7 +124,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -140,7 +140,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_model_not_support(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -156,7 +156,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_invoke_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -172,7 +172,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(CompletionRequestError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -188,7 +188,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_value_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -204,7 +204,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ValueError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_generic_exception(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -220,7 +220,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InternalServerError): - method(trial_app_workflow) + method(api, trial_app_workflow) class TestTrialChatApi: @@ -566,7 +566,7 @@ class TestTrialMessageSuggestedQuestionApi: with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(api, MagicMock(mode="completion"), str(uuid4())) + method(MagicMock(mode="completion"), str(uuid4())) def test_success(self, app, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() @@ -581,7 +581,7 @@ class TestTrialMessageSuggestedQuestionApi: return_value=["q1", "q2"], ), ): - result = method(api, trial_app_chat, str(uuid4())) + result = method(trial_app_chat, str(uuid4())) assert result == {"data": ["q1", "q2"]} @@ -599,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(api, trial_app_chat, str(uuid4())) + method(trial_app_chat, str(uuid4())) class TestTrialAppParameterApi: @@ -931,7 +931,7 @@ class TestTrialAppWorkflowTaskStopApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(trial_app_chat, str(uuid4())) + method(api, trial_app_chat, str(uuid4())) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowTaskStopApi() @@ -944,7 +944,7 @@ class TestTrialAppWorkflowTaskStopApi: patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, patch.object(module.GraphEngineManager, "send_stop_command") as 
mock_send_cmd, ): - result = method(trial_app_workflow, task_id) + result = method(api, trial_app_workflow, task_id) assert result == {"result": "success"} mock_set_flag.assert_called_once_with(task_id) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index e89b89c8b1..6405558bb4 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -1,13 +1,17 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import Forbidden +import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( - TagBindingCreateApi, - TagBindingDeleteApi, + DeprecatedTagBindingCreateApi, + DeprecatedTagBindingRemoveApi, + TagBindingCollectionApi, + TagBindingItemApi, TagListApi, TagUpdateDeleteApi, ) @@ -83,13 +87,20 @@ class TestTagListApi: ), patch( "controllers.console.tag.tags.TagService.get_tags", - return_value=[{"id": "1", "name": "tag"}], + return_value=[ + SimpleNamespace( + id="1", + name="tag", + type=TagType.KNOWLEDGE, + binding_count=1, + ) + ], ), ): result, status = method(api) assert status == 200 - assert isinstance(result, list) + assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] def test_post_success(self, app, admin_user, tag, payload_patch): api = TagListApi() @@ -113,6 +124,7 @@ class TestTagListApi: assert status == 200 assert result["name"] == "test-tag" + assert result["binding_count"] == "0" def test_post_forbidden(self, app, readonly_user, payload_patch): api = TagListApi() @@ -158,7 +170,7 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 - assert result["binding_count"] == 3 + assert result["binding_count"] == "3" def test_patch_forbidden(self, app, readonly_user, payload_patch): api = TagUpdateDeleteApi() @@ -195,9 +207,9 @@ class TestTagUpdateDeleteApi: assert status == 204 -class TestTagBindingCreateApi: +class TestTagBindingCollectionApi: def test_create_success(self, app, admin_user, payload_patch): - api = TagBindingCreateApi() + api = TagBindingCollectionApi() method = unwrap(api.post) payload = { @@ -222,7 +234,7 @@ class TestTagBindingCreateApi: assert result["result"] == "success" def test_create_forbidden(self, app, readonly_user, payload_patch): - api = TagBindingCreateApi() + api = TagBindingCollectionApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -237,9 +249,78 @@ class TestTagBindingCreateApi: method(api) -class TestTagBindingDeleteApi: +class TestDeprecatedTagBindingCreateApi: + def test_create_success(self, app, admin_user, payload_patch): + api = DeprecatedTagBindingCreateApi() + method = unwrap(api.post) + + payload = { + "tag_ids": ["tag-1"], + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, + ): + result, status = method(api) + + save_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + +class TestTagBindingItemApi: + def test_delete_success(self, app, admin_user, payload_patch): + api = TagBindingItemApi() + 
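# unwrap exposes the undecorated handler so it can be invoked directly under the test request context.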
method = unwrap(api.delete) + + payload = { + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, + ): + result, status = method(api, "tag-1") + + delete_mock.assert_called_once() + delete_payload = delete_mock.call_args.args[0] + assert delete_payload.tag_id == "tag-1" + assert delete_payload.target_id == "target-1" + assert delete_payload.type == TagType.KNOWLEDGE + assert status == 200 + assert result["result"] == "success" + + def test_delete_forbidden(self, app, readonly_user): + api = TagBindingItemApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ): + with pytest.raises(Forbidden): + method(api, "tag-1") + + +class TestDeprecatedTagBindingRemoveApi: def test_remove_success(self, app, admin_user, payload_patch): - api = TagBindingDeleteApi() + api = DeprecatedTagBindingRemoveApi() method = unwrap(api.post) payload = { @@ -264,7 +345,7 @@ class TestTagBindingDeleteApi: assert result["result"] == "success" def test_remove_forbidden(self, app, readonly_user, payload_patch): - api = TagBindingDeleteApi() + api = DeprecatedTagBindingRemoveApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -277,3 +358,45 @@ class TestTagBindingDeleteApi: ): with pytest.raises(Forbidden): method(api) + + +class TestTagResponseModel: + def test_tag_response_normalizes_enum_type(self): + payload = module.TagResponse.model_validate( + {"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1} + ).model_dump(mode="json") + + assert payload["type"] == "knowledge" + assert payload["binding_count"] == "1" + + +class TestTagBindingRouteMetadata: + def test_legacy_write_routes_are_marked_deprecated(self): + assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True + assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True + assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True + assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True + + def test_write_routes_have_stable_operation_ids(self): + assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding" + assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding" + assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated" + assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated" + + def test_canonical_and_legacy_write_routes_are_registered(self): + route_map = { + resource.__name__: urls + for resource, urls, _route_doc, _kwargs in console_ns.resources + if resource.__name__ + in { + "TagBindingCollectionApi", + "TagBindingItemApi", + "DeprecatedTagBindingCreateApi", + "DeprecatedTagBindingRemoveApi", + } + } + + assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",) + assert route_map["TagBindingItemApi"] == ("/tag-bindings/",) + assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",) + assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",) diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py 
b/api/tests/unit_tests/controllers/console/test_human_input_form.py index 232b6eee79..ebf803cac9 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) handler(api, form_token="token") +def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: submit_mock = Mock() form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index dd643faac9..0b1a32581a 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -11,7 +11,7 @@ from controllers.console.workspace.account import ( ChangeEmailSendEmailApi, CheckEmailUnique, ) -from models import Account +from models import Account, AccountStatus from services.account_service import AccountService @@ -24,16 +24,12 @@ def app(): return app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") account = Account(name=account_id, email=email) account.email = email account.id = account_id - account.status = "active" + account.status = AccountStatus.ACTIVE account._current_tenant = tenant_obj return account @@ -64,11 +60,13 @@ class TestChangeEmailSend: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = {"email": "current@example.com"} + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + } mock_send_email.return_value = "token-abc" with app.test_request_context( @@ -85,12 +83,54 @@ class TestChangeEmailSend: email="new@example.com", old_email="current@example.com", language="en-US", - phase="new_email", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, ) mock_extract_ip.assert_called_once() mock_is_ip_limit.assert_called_once_with("127.0.0.1") mock_csrf.assert_called_once() + 
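# The remaining cases pin the change-email phase-token hardening (GHSA-4q3w-q5mc-45rq).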
@patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -117,12 +157,16 @@ class TestChangeEmailValidity: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } mock_generate_token.return_value = (None, "new-token") with app.test_request_context( @@ -138,11 +182,166 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", code="1234", old_email="old@example.com", additional_data={} + "user@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + }, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + 
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_upgrade_new_phase_token_to_new_verified( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "new@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + } + mock_generate_token.return_value = (None, "new-verified-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} + mock_generate_token.assert_called_once_with( + "new@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + }, + ) + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_phase_is_unknown( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token whose phase marker is a string but not a known transition must be rejected.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": 
"token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_has_no_phase( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + class TestChangeEmailReset: @patch("controllers.console.wraps.db") @@ -169,13 +368,16 @@ class TestChangeEmailReset: mock_db, app, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_get_data.return_value = { + "email": "new@example.com", + "old_email": "OLD@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -194,12 +396,158 @@ class TestChangeEmailReset: mock_send_notify.assert_called_once_with(email="new@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + 
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_phase_is_not_new_verified( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + # Simulate a token straight out of step #1 (phase=old_email) — exactly + # the replay used in the advisory PoC. + mock_get_data.return_value = { + "email": "old@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-from-step1"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_email_differs_from_payload_new_email( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """A verified token for address A must not be replayed to change to address B.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = { + "email": "verified@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: 
AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-verified"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + +class TestAccountServiceSendChangeEmailEmail: + """Service-level coverage for the phase-bound changes in `send_change_email_email`.""" + + def test_should_raise_value_error_for_invalid_phase(self): + with pytest.raises(ValueError, match="phase must be one of"): + AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + phase="old_email_verified", + ) + + @patch("services.account_service.send_change_mail_task") + @patch("services.account_service.AccountService.change_email_rate_limiter") + @patch("services.account_service.AccountService.generate_change_email_token") + def test_should_stamp_phase_into_generated_token( + self, + mock_generate_token, + mock_rate_limiter, + mock_mail_task, + ): + mock_rate_limiter.is_rate_limited.return_value = False + mock_generate_token.return_value = ("123456", "the-token") + + returned = AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + language="en-US", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + + assert returned == "the-token" + mock_generate_token.assert_called_once_with( + "user@example.com", + None, + old_email="user@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + }, + ) + mock_mail_task.delay.assert_called_once() + mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") def test_should_normalize_feedback_email(self, mock_update, mock_db, app): - _mock_wraps_db(mock_db) with app.test_request_context( "/account/delete/feedback", method="POST", @@ -216,7 +564,6 @@ class TestCheckEmailUnique: @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): - _mock_wraps_db(mock_db) mock_is_freeze.return_value = False mock_check_unique.return_value = True diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 239fec8430..811bf5b1e7 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask, g @@ -16,10 +16,6 @@ def app(): return flask_app -def _mock_wraps_db(mock_db): - mock_db.session.query.return_value.first.return_value = MagicMock() - - def _build_feature_flags(): placeholder_quota = SimpleNamespace(limit=0, size=0) workspace_members = SimpleNamespace(is_available=lambda count: True) @@ -49,7 +45,6 @@ class TestMemberInviteEmailApi: mock_get_features, app, ): - 
_mock_wraps_db(mock_db) mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index f6e096a97b..aa4973851a 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -310,7 +310,6 @@ class TestSystemSetup: def test_should_allow_when_setup_complete(self, mock_db): """Test that requests are allowed when setup is complete""" # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists @setup_required def admin_view(): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py index 51f76af172..0b3d7ef6d7 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -2,14 +2,17 @@ from unittest.mock import MagicMock, patch import pytest +from controllers.console import console_ns from controllers.console.workspace.endpoint import ( - EndpointCreateApi, - EndpointDeleteApi, + DeprecatedEndpointCreateApi, + DeprecatedEndpointDeleteApi, + DeprecatedEndpointUpdateApi, + EndpointCollectionApi, EndpointDisableApi, EndpointEnableApi, + EndpointItemApi, EndpointListApi, EndpointListForSinglePluginApi, - EndpointUpdateApi, ) from core.plugin.impl.exc import PluginPermissionDeniedError @@ -35,9 +38,9 @@ def patch_current_account(user_and_tenant): @pytest.mark.usefixtures("patch_current_account") -class TestEndpointCreateApi: +class TestEndpointCollectionApi: def test_create_success(self, app): - api = EndpointCreateApi() + api = EndpointCollectionApi() method = unwrap(api.post) payload = { @@ -55,7 +58,7 @@ class TestEndpointCreateApi: assert result["success"] is True def test_create_permission_denied(self, app): - api = EndpointCreateApi() + api = EndpointCollectionApi() method = unwrap(api.post) payload = { @@ -75,7 +78,7 @@ class TestEndpointCreateApi: method(api) def test_create_validation_error(self, app): - api = EndpointCreateApi() + api = EndpointCollectionApi() method = unwrap(api.post) payload = { @@ -91,6 +94,27 @@ class TestEndpointCreateApi: method(api) +@pytest.mark.usefixtures("patch_current_account") +class TestDeprecatedEndpointCreateApi: + def test_create_success(self, app): + api = DeprecatedEndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {"a": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + @pytest.mark.usefixtures("patch_current_account") class TestEndpointListApi: def test_list_success(self, app): @@ -146,9 +170,96 @@ class TestEndpointListForSinglePluginApi: @pytest.mark.usefixtures("patch_current_account") -class TestEndpointDeleteApi: +class TestEndpointItemApi: def test_delete_success(self, app): - api = EndpointDeleteApi() + api = EndpointItemApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/", method="DELETE"), + patch( + "controllers.console.workspace.endpoint.EndpointService.delete_endpoint", + return_value=True, + ) as mock_delete, + ): + result = method(api, "e1") + + assert result["success"] is True + 
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1") + + def test_delete_service_failure(self, app): + api = EndpointItemApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/", method="DELETE"), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), + ): + result = method(api, "e1") + + assert result["success"] is False + + def test_update_success(self, app): + api = EndpointItemApi() + method = unwrap(api.patch) + + payload = { + "name": "new-name", + "settings": {"x": 1}, + } + + with ( + app.test_request_context("/", method="PATCH", json=payload), + patch( + "controllers.console.workspace.endpoint.EndpointService.update_endpoint", + return_value=True, + ) as mock_update, + ): + result = method(api, "e1") + + assert result["success"] is True + mock_update.assert_called_once_with( + tenant_id="t1", + user_id="u1", + endpoint_id="e1", + name="new-name", + settings={"x": 1}, + ) + + def test_update_validation_error(self, app): + api = EndpointItemApi() + method = unwrap(api.patch) + + payload = {"settings": {}} + + with ( + app.test_request_context("/", method="PATCH", json=payload), + ): + with pytest.raises(ValueError): + method(api, "e1") + + def test_update_service_failure(self, app): + api = EndpointItemApi() + method = unwrap(api.patch) + + payload = { + "name": "n", + "settings": {}, + } + + with ( + app.test_request_context("/", method="PATCH", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), + ): + result = method(api, "e1") + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestDeprecatedEndpointDeleteApi: + def test_delete_success(self, app): + api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -162,7 +273,7 @@ class TestEndpointDeleteApi: assert result["success"] is True def test_delete_invalid_payload(self, app): - api = EndpointDeleteApi() + api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) with ( @@ -172,7 +283,7 @@ class TestEndpointDeleteApi: method(api) def test_delete_service_failure(self, app): - api = EndpointDeleteApi() + api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -187,9 +298,9 @@ class TestEndpointDeleteApi: @pytest.mark.usefixtures("patch_current_account") -class TestEndpointUpdateApi: +class TestDeprecatedEndpointUpdateApi: def test_update_success(self, app): - api = EndpointUpdateApi() + api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) payload = { @@ -207,7 +318,7 @@ class TestEndpointUpdateApi: assert result["success"] is True def test_update_validation_error(self, app): - api = EndpointUpdateApi() + api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) payload = {"endpoint_id": "e1", "settings": {}} @@ -219,7 +330,7 @@ class TestEndpointUpdateApi: method(api) def test_update_service_failure(self, app): - api = EndpointUpdateApi() + api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) payload = { @@ -237,6 +348,36 @@ class TestEndpointUpdateApi: assert result["success"] is False +class TestEndpointRouteMetadata: + def test_legacy_write_routes_are_marked_deprecated(self): + assert DeprecatedEndpointCreateApi.post.__apidoc__["deprecated"] is True + assert DeprecatedEndpointDeleteApi.post.__apidoc__["deprecated"] is True + assert DeprecatedEndpointUpdateApi.post.__apidoc__["deprecated"] is True + assert 
EndpointCollectionApi.post.__apidoc__.get("deprecated") is not True + assert EndpointItemApi.delete.__apidoc__.get("deprecated") is not True + assert EndpointItemApi.patch.__apidoc__.get("deprecated") is not True + + def test_canonical_and_legacy_write_routes_are_registered(self): + route_map = { + resource.__name__: urls + for resource, urls, _route_doc, _kwargs in console_ns.resources + if resource.__name__ + in { + "EndpointCollectionApi", + "EndpointItemApi", + "DeprecatedEndpointCreateApi", + "DeprecatedEndpointDeleteApi", + "DeprecatedEndpointUpdateApi", + } + } + + assert route_map["EndpointCollectionApi"] == ("/workspaces/current/endpoints",) + assert route_map["EndpointItemApi"] == ("/workspaces/current/endpoints/",) + assert route_map["DeprecatedEndpointCreateApi"] == ("/workspaces/current/endpoints/create",) + assert route_map["DeprecatedEndpointDeleteApi"] == ("/workspaces/current/endpoints/delete",) + assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",) + + @pytest.mark.usefixtures("patch_current_account") class TestEndpointEnableApi: def test_enable_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 9c42ee9529..b2f949c6e2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,9 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from werkzeug.exceptions import Forbidden + from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index fb9eec98cb..168479af1e 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -14,6 +13,7 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index c829327bc7..f0d32f81fb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,8 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -16,6 +14,8 @@ from 
controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 44feacf2ad..1422f29849 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None @contextmanager def _mock_db(): - mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True) with patch("extensions.ext_database.db.session", mock_session): yield diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf..e82a29f045 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -435,6 +436,23 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: def test_switch_success(self, app): api = SwitchWorkspaceApi() diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 974d8f7bc6..71381e6a2b 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -18,6 +18,7 @@ from controllers.inner_api.app.dsl import ( InnerAppDSLImportPayload, _get_active_account, ) +from models.account import AccountStatus from services.app_dsl_service import ImportStatus @@ -63,7 +64,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_active_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "active" + mock_account.status = AccountStatus.ACTIVE mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") @@ -74,7 +75,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "banned" + mock_account.status = AccountStatus.BANNED mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -102,16 +103,16 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): - """Patch db, sessionmaker, and AppDslService for import handler tests.""" - mock_session_ctx = MagicMock() - 
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock()) - mock_session_ctx.__exit__ = MagicMock(return_value=False) - mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx))) + """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker), + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -147,6 +148,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -162,6 +165,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -177,6 +182,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 957d7fbd9b..d1b09c3a58 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -2,6 +2,7 @@ Unit tests for inner_api plugin decorators """ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -40,17 +41,22 @@ class TestTenantUserPayload: class TestGetUser: """Test get_user function""" + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -58,13 +64,45 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.get.assert_called_once() + 
mock_session.scalar.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_not_resolve_non_anonymous_users_across_tenants( + self, + mock_db, + mock_sessionmaker, + mock_enduser_class, + mock_select, + app: Flask, + ): + """Test that explicit user IDs remain scoped to the current tenant.""" + # Arrange + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + mock_new_user = MagicMock() + mock_new_user.tenant_id = "tenant-current" + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant-current", "foreign-user-id") + + # Assert + assert result == mock_new_user + mock_session.get.assert_not_called() + mock_session.scalar.assert_called_once() + mock_session.add.assert_called_once_with(mock_new_user) + + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange @@ -72,8 +110,9 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - # non-anonymous path uses session.get(); anonymous uses session.scalar() - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -82,17 +121,22 @@ class TestGetUser: # Assert assert result == mock_user + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = None + mock_session.scalar.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -133,7 +177,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.side_effect = Exception("Database error") + mock_session.scalar.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -232,11 +276,11 @@ class TestGetUserTenant: class PluginTestPayload: """Simple test payload class""" - def __init__(self, data: dict): + def 
__init__(self, data: dict[str, Any]): self.value = data.get("value") @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): return cls(data) @@ -277,7 +321,7 @@ class TestPluginData: # Arrange class InvalidPayload: @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): raise Exception("Validation failed") @plugin_data(payload_type=InvalidPayload) diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 56a8f94963..7d2193adc6 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -20,6 +20,7 @@ from controllers.inner_api.workspace.workspace import ( WorkspaceCreatePayload, WorkspaceOwnerlessPayload, ) +from models.account import TenantStatus class TestWorkspaceCreatePayload: @@ -98,7 +99,7 @@ class TestEnterpriseWorkspace: mock_tenant.id = "tenant-id" mock_tenant.name = "My Workspace" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.created_at = now mock_tenant.updated_at = now mock_tenant_svc.create_tenant.return_value = mock_tenant @@ -162,7 +163,7 @@ class TestEnterpriseWorkspaceNoOwnerEmail: mock_tenant.name = "My Workspace" mock_tenant.encrypt_public_key = "pub-key" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.custom_config = None mock_tenant.created_at = now mock_tenant.updated_at = now diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index 1507bf7a5f..f5d93b5ac3 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -10,8 +10,9 @@ from flask import Flask from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi from controllers.service_api.app.error import AppUnavailableError +from models.account import TenantStatus from models.model import App, AppMode -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result class TestAppParameterApi: @@ -62,7 +63,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL # Mock DB queries for app and tenant mock_db.session.get.side_effect = [ @@ -73,7 +74,7 @@ class TestAppParameterApi: # Mock tenant owner info for login mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -110,7 +111,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -119,7 +120,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, 
mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -151,7 +152,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -160,7 +161,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -190,7 +191,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -199,7 +200,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -253,7 +254,7 @@ class TestAppMetaApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -262,7 +263,7 @@ class TestAppMetaApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -321,7 +322,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -330,7 +331,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -378,7 +379,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -387,7 +388,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -424,7 +425,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -433,7 +434,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - 
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -476,7 +477,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -485,7 +486,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index a26fea8fbd..c16ebad739 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,7 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -30,6 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 57681d8f5b..3364c07e62 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,7 +16,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -35,6 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index dbd06677d8..4fb8ecf784 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -15,6 +15,7 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch @@ -29,11 +30,16 @@ from controllers.service_api.app.conversation import ( ConversationRenameApi, ConversationRenamePayload, ConversationVariableDetailApi, + ConversationVariableInfiniteScrollPaginationResponse, + ConversationVariableResponse, 
ConversationVariablesApi, ConversationVariablesQuery, ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment +from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -261,6 +267,72 @@ class TestConversationVariableUpdatePayload: assert payload.value == nested +class TestConversationVariableResponseModels: + def test_variable_response_normalizes_value_type_and_timestamps(self): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "description": "desc", + "created_at": created_at, + "updated_at": created_at, + } + ) + assert response.value_type == "number" + assert response.value == "1" + assert response.created_at == int(created_at.timestamp()) + assert response.updated_at == int(created_at.timestamp()) + + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + + def test_variable_pagination_response(self): + response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( + { + "limit": 1, + "has_more": False, + "data": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": "string", + "value": "bar", + } + ], + } + ) + assert response.limit == 1 + assert response.has_more is False + assert len(response.data) == 1 + assert response.data[0].name == "foo" + + class TestConversationAppModeValidation: """Test app mode validation for conversation endpoints.""" @@ -549,6 +621,44 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "get_conversational_variable", + lambda *_args, **_kwargs: SimpleNamespace( + limit=1, + has_more=False, + data=[ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + } + ], + ), + ) + + api = ConversationVariablesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables?limit=20", + method="GET", + ): 
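+            # The unwrapped handler returns the already-serialized response
+            # mapping, so the assertions below can check the public JSON
+            # contract (stringified values, epoch-second timestamps) directly.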
+ result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + + assert result["limit"] == 1 + assert result["has_more"] is False + assert result["data"][0]["value_type"] == "number" + assert result["data"][0]["value"] == "1" + assert result["data"][0]["created_at"] == int(created_at.timestamp()) + class TestConversationVariableDetailApiController: def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -602,3 +712,41 @@ class TestConversationVariableDetailApiController: c_id="00000000-0000-0000-0000-000000000001", variable_id="00000000-0000-0000-0000-000000000002", ) + + def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "update_conversation_variable", + lambda *_args, **_kwargs: { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + }, + ) + + api = ConversationVariableDetailApi() + handler = _unwrap(api.put) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables/2", + method="PUT", + json={"value": 1}, + ): + result = handler( + api, + app_model=app_model, + end_user=end_user, + c_id="00000000-0000-0000-0000-000000000001", + variable_id="00000000-0000-0000-0000-000000000002", + ) + + assert result["id"] == "550e8400-e29b-41d4-a716-446655440000" + assert result["value_type"] == "number" + assert result["value"] == "1" + assert result["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py new file mode 100644 index 0000000000..846d5368f3 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -0,0 +1,707 @@ +"""Dedicated tests for HITL behavior exposed through the Service API.""" + +from __future__ import annotations + +import json +import sys +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import ANY, MagicMock, Mock + +import pytest + +import services.app_generate_service as ags_module +from controllers.service_api.app.workflow_events import WorkflowEventsApi +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.common import workflow_response_converter +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, + HumanInputRequiredResponse, + WorkflowAppPausedBlockingResponse, + WorkflowPauseStreamResponse, +) +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from core.workflow.human_input_policy import HumanInputSurface +from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType +from 
graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType +from graphon.runtime import GraphRuntimeState, VariablePool +from models.account import Account +from models.enums import CreatorUserRole +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot +from repositories.entities.workflow_pause import WorkflowPauseEntity +from services.app_generate_service import AppGenerateService +from services.workflow_event_snapshot_service import _build_snapshot_events +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +class _DummyRateLimit: + @staticmethod + def gen_request_key() -> str: + return "dummy-request-id" + + def __init__(self, client_id: str, max_active_requests: int) -> None: + self.client_id = client_id + self.max_active_requests = max_active_requests + + def enter(self, request_id: str | None = None) -> str: + return request_id or "dummy-request-id" + + def exit(self, request_id: str) -> None: + return None + + def generate(self, generator, request_id: str): + return generator + + +def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): + workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"] + repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run) + monkeypatch.setattr( + workflow_events_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: repo, + ) + monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object())) + return workflow_events_module + + +def _build_service_api_pause_converter() -> WorkflowResponseConverter: + application_generate_entity = SimpleNamespace( + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), + ) + system_variables = build_system_variables( + user_id="user", + app_id="app-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + ) + user = MagicMock(spec=Account) + user.id = "account-id" + user.name = "Tester" + user.email = "tester@example.com" + return WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + user=user, + system_variables=system_variables, + ) + + +def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse: + data = AdvancedChatPausedBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + workflow_run_id="run-1", + answer="partial", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + paused_nodes=["node-1"], + reasons=[ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "expiration_time": 100, + } + ], + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ) + return AdvancedChatPausedBlockingResponse(task_id="t1", data=data) + + +def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse: + return WorkflowAppPausedBlockingResponse( + task_id="t1", + workflow_run_id="r1", + data=WorkflowAppPausedBlockingResponse.Data( + id="r1", + workflow_id="wf-1", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + error=None, + elapsed_time=0.5, + total_tokens=0, + total_steps=2, + created_at=1, + finished_at=None, + 
paused_nodes=["node-1"], + reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}], + ), + ) + + +@dataclass(frozen=True) +class _FakePauseEntity(WorkflowPauseEntity): + pause_id: str + workflow_run_id: str + paused_at_value: datetime + pause_reasons: Sequence[HumanInputRequired] + + @property + def id(self) -> str: + return self.pause_id + + @property + def workflow_execution_id(self) -> str: + return self.workflow_run_id + + def get_state(self) -> bytes: + raise AssertionError("state is not required for snapshot tests") + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def paused_at(self) -> datetime: + return self.paused_at_value + + def get_pause_reasons(self) -> Sequence[HumanInputRequired]: + return self.pause_reasons + + +def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + inputs=json.dumps({"input": "value"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: + created_at = datetime(2024, 1, 1, tzinfo=UTC) + finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + return WorkflowNodeExecutionSnapshot( + execution_id="exec-1", + node_id="node-1", + node_type="human-input", + title="Human Input", + index=1, + status=status.value, + elapsed_time=0.5, + created_at=created_at, + finished_at=finished_at, + iteration_id=None, + loop_id=None, + ) + + +def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.register_paused_node("node-1") + runtime_state.outputs = {"result": "value"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + +class TestHitlServiceApi: + # Service API event-stream continuation + def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + msg_generator.retrieve_events.return_value = ["raw-event"] + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = 
SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: streamed\n\n" + msg_generator.retrieve_events.assert_called_once_with( + AppMode.WORKFLOW, + "run-1", + terminal_events=[], + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) + + def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open( + self, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"]) + snapshot_builder = Mock(return_value=["snapshot-events"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true", + method="GET", + ): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: snapshot\n\n" + msg_generator.retrieve_events.assert_not_called() + snapshot_builder.assert_called_once_with( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=ANY, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=False, + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"]) + + def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit) + + workflow = MagicMock() + workflow.created_by = "owner-id" + monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow) + monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker") + + generator_instance = MagicMock() + generator_instance.generate.return_value = {"result": "advanced-blocking"} + generator_instance.convert_to_event_stream.side_effect = lambda payload: payload + monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance) + + app_model = MagicMock() + app_model.mode = AppMode.ADVANCED_CHAT + app_model.id = "app-id" + app_model.tenant_id = "tenant-id" + app_model.max_active_requests = 0 + app_model.is_agent = False + + user = MagicMock() + user.id = "user-id" + + result = AppGenerateService.generate( + app_model=app_model, + user=user, + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + 
assert result == {"result": "advanced-blocking"} + call_kwargs = generator_instance.generate.call_args.kwargs + assert call_kwargs["streaming"] is False + assert call_kwargs["pause_state_config"] is not None + assert call_kwargs["pause_state_config"].session_factory == "session-maker" + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # Blocking payload contract + def test_advanced_chat_blocking_pause_payload_contract(self) -> None: + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response( + _build_advanced_chat_paused_blocking_response() + ) + + assert response["event"] == "workflow_paused" + assert response["workflow_run_id"] == "run-1" + assert response["answer"] == "partial" + assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED + assert response["data"]["reasons"][0]["expiration_time"] == 100 + assert "human_input_forms" not in response["data"] + + def test_workflow_blocking_pause_payload_contract(self) -> None: + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter + + response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response( + _build_workflow_paused_blocking_response() + ) + + assert response["workflow_run_id"] == "r1" + assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED + assert response["data"]["paused_nodes"] == ["node-1"] + assert response["data"]["reasons"] == [ + {"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100} + ] + assert "human_input_forms" not in response["data"] + + def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None: + from core.app.app_config.entities import AppAdditionalFeatures + from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline + from models.enums import MessageStatus + from models.model import EndUser + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ), + user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"), + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need 
approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run-id", + paused_nodes=["node-1"], + outputs={}, + reasons=[ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "expiration_time": 123, + }, + ], + status="paused", + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.answer == "partial answer" + assert response.data.workflow_run_id == "run-id" + assert response.data.reasons[0]["form_id"] == "form-1" + assert response.data.reasons[0]["expiration_time"] == 123 + + def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module + from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=SimpleNamespace(id="user", session_id="session"), + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + paused_nodes=["node-1"], + reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}], + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}] + + def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None: + converter = _build_service_api_pause_converter() + converter.workflow_start_to_stream_response( + task_id="task", + workflow_run_id="run-id", + workflow_id="workflow-id", + 
reason=WorkflowStartReason.INITIAL, + ) + + expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + + class _FakeSession: + def execute(self, _stmt): + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) + monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + workflow_response_converter, + "load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "token"}, + ) + + reason = HumanInputRequired( + form_id="form-1", + form_content="Rendered", + inputs=[ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), + ], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + node_id="node-id", + node_title="Human Step", + form_token="token", + ) + queue_event = QueueWorkflowPausedEvent( + reasons=[reason], + outputs={"answer": "value"}, + paused_nodes=["node-id"], + ) + + runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) + responses = converter.workflow_pause_to_stream_response( + event=queue_event, + task_id="task", + graph_runtime_state=runtime_state, + ) + + assert isinstance(responses[-1], WorkflowPauseStreamResponse) + pause_resp = responses[-1] + assert pause_resp.workflow_run_id == "run-id" + assert pause_resp.data.paused_nodes == ["node-id"] + assert pause_resp.data.outputs == {} + assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required" + assert pause_resp.data.reasons[0]["form_id"] == "form-1" + assert pause_resp.data.reasons[0]["form_token"] == "token" + assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp()) + + assert isinstance(responses[0], HumanInputRequiredResponse) + hi_resp = responses[0] + assert hi_resp.data.form_id == "form-1" + assert hi_resp.data.node_id == "node-id" + assert hi_resp.data.node_title == "Human Step" + assert hi_resp.data.inputs[0].output_variable_name == "field" + assert hi_resp.data.actions[0].id == "approve" + assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "token" + assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) + + # Snapshot payload contract + def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + monkeypatch.setattr( + "services.workflow_event_snapshot_service.load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "wtok"}, + ) + + class _SessionContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return False + + def session_maker() -> _SessionContext: + return _SessionContext( + SimpleNamespace( + execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')], + ) + ) + + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + form_token="wtok", + ) + ], + ) + + events = 
_build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + session_maker=session_maker, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "node_started", + "node_finished", + "human_input_required", + "workflow_paused", + ] + assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value + assert events[3]["data"]["form_token"] == "wtok" + assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + pause_data = events[-1]["data"] + assert pause_data["paused_nodes"] == ["node-1"] + assert pause_data["outputs"] == {"result": "value"} + assert pause_data["reasons"][0]["TYPE"] == "human_input_required" + assert pause_data["reasons"][0]["form_token"] == "wtok" + assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value + assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) + assert pause_data["elapsed_time"] == workflow_run.elapsed_time + assert pause_data["total_tokens"] == workflow_run.total_tokens + assert pause_data["total_steps"] == workflow_run.total_steps diff --git a/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py new file mode 100644 index 0000000000..531f722ceb --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py @@ -0,0 +1,184 @@ +"""Unit tests for Service API human input form endpoints.""" + +from __future__ import annotations + +import json +import sys +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi +from models.human_input import RecipientType +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +class TestWorkflowHumanInputFormApi: + def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + definition = SimpleNamespace( + model_dump=lambda: { + "rendered_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, + "user_actions": [{"id": "approve", "title": "Approve"}], + } + ) + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + get_definition=lambda: definition, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + response = handler(api, app_model=app_model, form_token="token-1") + + payload = json.loads(response.get_data(as_text=True)) + assert payload == { + "form_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + 
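# Non-string defaults come back serialized: 30 -> "30", nested dicts -> JSON strings. +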
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}, + "user_actions": [{"id": "approve", "title": "Approve"}], + "expiration_time": int(form.expiration_time.timestamp()), + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + service_mock.ensure_form_active.assert_called_once_with(form) + + def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="another-app", + tenant_id="tenant-1", + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_get_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + service_mock.ensure_form_active.assert_not_called() + + def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + assert response == {} + assert status == 200 + service_mock.submit_form_by_token.assert_called_once_with( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token-1", + selected_action_id="approve", + form_data={"name": 
"Alice"}, + submission_end_user_id="end-user-1", + ) + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_post_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + service_mock.submit_form_by_token.assert_not_called() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index cfa21bf2dd..da09ec13ce 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -15,11 +15,11 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -36,6 +36,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService +def _make_mock_workflow_run(run_id: str = "run-1"): + run = Mock() + run.id = run_id + run.workflow_id = "wf-1" + run.status = WorkflowExecutionStatus.SUCCEEDED + run.inputs = {"input": "value"} + run.outputs_dict = {"output": "value"} + run.error = None + run.total_steps = 1 + run.total_tokens = 10 + run.created_at = datetime(2026, 1, 1, tzinfo=UTC) + run.finished_at = datetime(2026, 1, 1, tzinfo=UTC) + run.elapsed_time = 0.1 + return run + + class TestWorkflowRunPayload: """Test suite for WorkflowRunPayload Pydantic model.""" @@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi: handler(api, app_model=app_model, workflow_run_id="run") def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - run = SimpleNamespace(id="run") + run = _make_mock_workflow_run(run_id="run") repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) @@ -373,7 +390,10 @@ 
class TestWorkflowRunDetailApi: handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") - assert handler(api, app_model=app_model, workflow_run_id="run") == run + result = handler(api, app_model=app_model, workflow_run_id="run") + assert result["id"] == "run" + assert result["workflow_id"] == "wf-1" + assert result["status"] == "succeeded" class TestWorkflowRunApi: @@ -490,7 +510,7 @@ class TestWorkflowAppLogApi: monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", - lambda *_args, **_kwargs: {"items": [], "total": 0}, + lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}, ) api = WorkflowAppLogApi() @@ -500,7 +520,7 @@ class TestWorkflowAppLogApi: with app.test_request_context("/workflows/logs", method="GET"): response = handler(api, app_model=app_model) - assert response == {"items": [], "total": 0} + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} # ============================================================================= @@ -527,9 +547,8 @@ def mock_workflow_app(): class TestWorkflowRunDetailApiGet: """Test suite for WorkflowRunDetailApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) - and ``@service_api_ns.marshal_with``. We call the unwrapped method - directly; ``marshal_with`` is a no-op when calling directly. + ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``), + and we call the unwrapped method directly in tests. """ @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet: mock_workflow_app, ): """Test successful workflow run detail retrieval.""" - mock_run = Mock() - mock_run.id = "run-1" - mock_run.status = "succeeded" + mock_run = _make_mock_workflow_run(run_id="run-1") mock_repo = Mock() mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo @@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet: api = WorkflowRunDetailApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) - assert result == mock_run + assert result["id"] == mock_run.id + assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") def test_get_workflow_run_wrong_app_mode(self, mock_db, app): @@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost: class TestWorkflowAppLogApiGet: """Test suite for WorkflowAppLogApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` and - ``@service_api_ns.marshal_with``. + ``get`` is wrapped by ``@validate_app_token``. 
""" @patch("controllers.service_api.app.workflow.WorkflowAppService") @@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet: ): """Test successful workflow log retrieval.""" mock_pagination = Mock() + mock_pagination.page = 1 + mock_pagination.limit = 20 + mock_pagination.total = 0 + mock_pagination.has_more = False mock_pagination.data = [] mock_svc_instance = Mock() mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination @@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet: api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) - assert result == mock_pagination + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py new file mode 100644 index 0000000000..f45a7f9632 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py @@ -0,0 +1,166 @@ +"""Unit tests for Service API workflow event stream endpoints.""" + +from __future__ import annotations + +import json +import sys +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.service_api.app.error import NotWorkflowAppError +from controllers.service_api.app.workflow_events import WorkflowEventsApi +from models.enums import CreatorUserRole +from models.model import AppMode +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): + workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"] + repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run) + monkeypatch.setattr( + workflow_events_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: repo, + ) + monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object())) + return workflow_events_module + + +class TestWorkflowEventsApi: + def test_wrong_app_mode(self, app) -> None: + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotWorkflowAppError): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + _mock_repo_for_run(monkeypatch, workflow_run=None) + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="another-user", + finished_at=None, + ) + _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", 
mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=datetime(2099, 1, 1, tzinfo=UTC), + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + monkeypatch.setattr( + workflow_events_module.WorkflowResponseConverter, + "workflow_run_result_to_finish_response", + lambda **_kwargs: SimpleNamespace( + model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"}, + event=SimpleNamespace(value="workflow_finished"), + ), + ) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.mimetype == "text/event-stream" + body = response.get_data(as_text=True).strip() + assert body.startswith("data: ") + payload = json.loads(body[len("data: ") :]) + assert payload["task_id"] == "run-1" + assert payload["event"] == "workflow_finished" + + def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + msg_generator.retrieve_events.return_value = ["raw-event"] + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: streamed\n\n" + msg_generator.retrieve_events.assert_called_once_with( + AppMode.WORKFLOW, + "run-1", + terminal_events=None, + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) + + def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="end-user-1", + finished_at=None, + ) + workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) + msg_generator = Mock() + workflow_generator = Mock() + workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"]) + snapshot_builder = Mock(return_value=["snapshot-events"]) + monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator) + 
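# The live event poller must stay idle in snapshot mode; the snapshot builder supplies the stream instead. +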
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) + monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) + + api = WorkflowEventsApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"): + response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1") + + assert response.get_data(as_text=True) == "data: snapshot\n\n" + msg_generator.retrieve_events.assert_not_called() + snapshot_builder.assert_called_once() + workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"]) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 4b8e3a738c..eda270258d 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,7 @@ from types import SimpleNamespace -from graphon.enums import WorkflowExecutionStatus - from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index eddba5a517..8c89812cb4 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -15,7 +15,10 @@ from flask import Flask from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import ( + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, +) @pytest.fixture @@ -123,9 +126,7 @@ class AuthenticationMocker: mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: - mock_ta = Mock() - mock_ta.account_id = mock_account.id - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @staticmethod def setup_dataset_auth(mock_db, mock_tenant, mock_account): @@ -133,8 +134,7 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 12d5e7345d..1b391e67ec 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -22,6 +22,8 @@ import pytest from werkzeug.exceptions import Forbidden, NotFound from controllers.service_api.dataset.document import ( + DeprecatedDocumentAddByTextApi, + DeprecatedDocumentUpdateByTextApi, DocumentAddByFileApi, DocumentAddByTextApi, DocumentApi, @@ -699,8 +701,8 @@ class 
TestDocumentApiDelete: ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which internally calls ``validate_and_get_api_token``. To bypass the decorator we call the original function via ``__wrapped__`` (preserved by - ``functools.wraps``). ``delete`` queries the dataset via - ``db.session.query(Dataset)`` directly, so we patch ``db`` at the + ``functools.wraps``). ``delete`` loads the dataset via + ``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the controller module. """ @@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi: # Act with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={ "name": "Test Document", @@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1093,6 +1095,20 @@ class TestArchivedDocumentImmutableError: assert error.code == 403 +class TestDocumentTextRouteDeprecation: + """Test that legacy underscore text routes stay marked deprecated.""" + + def test_create_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy create-by-text alias is marked deprecated.""" + assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True + + def test_update_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy update-by-text alias is marked deprecated.""" + assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True + + # ============================================================================= # Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi, # DocumentUpdateByFileApi. 
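Throughout these suites, handlers are invoked via ``_unwrap`` rather than through the decorated view, relying on the fact (noted in several docstrings above) that the auth and billing decorators preserve ``__wrapped__`` via ``functools.wraps``. A minimal sketch of such a helper, assuming it only walks the ``__wrapped__`` chain — the real implementation lives in ``tests/unit_tests/controllers/service_api/conftest.py`` and is not shown in this patch:

    from typing import Any, Callable

    def _unwrap(func: Callable[..., Any]) -> Callable[..., Any]:
        # Decorators applied with functools.wraps expose the original
        # callable via __wrapped__; follow the chain to the bare handler.
        while hasattr(func, "__wrapped__"):
            func = func.__wrapped__
        return func

Calling the bare method skips token validation and rate-limit checks entirely, which is why each test passes ``app_model`` and ``end_user`` explicitly instead of relying on the decorators to inject them.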
@@ -1162,7 +1178,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Updated Doc", "text": "New content"}, headers={"Authorization": "Bearer test_token"}, @@ -1195,7 +1211,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Doc", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py deleted file mode 100644 index c0b40d070a..0000000000 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Unit tests for Service API Site controller -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.service_api.app.site import AppSiteApi -from models.account import TenantStatus -from models.model import App, Site -from tests.unit_tests.conftest import setup_mock_tenant_account_query - - -class TestAppSiteApi: - """Test suite for AppSiteApi""" - - @pytest.fixture - def mock_app_model(self): - """Create a mock App model with tenant.""" - app = Mock(spec=App) - app.id = str(uuid.uuid4()) - app.tenant_id = str(uuid.uuid4()) - app.status = "normal" - app.enable_api = True - - mock_tenant = Mock() - mock_tenant.id = app.tenant_id - mock_tenant.status = TenantStatus.NORMAL - app.tenant = mock_tenant - - return app - - @pytest.fixture - def mock_site(self): - """Create a mock Site model.""" - site = Mock(spec=Site) - site.id = str(uuid.uuid4()) - site.app_id = str(uuid.uuid4()) - site.title = "Test Site" - site.icon = "icon-url" - site.icon_background = "#ffffff" - site.description = "Site description" - site.copyright = "Copyright 2024" - site.privacy_policy = "Privacy policy text" - site.custom_disclaimer = "Custom disclaimer" - site.default_language = "en-US" - site.prompt_public = True - site.show_workflow_steps = True - site.use_icon_as_answer_icon = False - site.chat_color_theme = "light" - site.chat_color_theme_inverted = False - site.icon_type = "image" - site.created_at = "2024-01-01T00:00:00" - site.updated_at = "2024-01-01T00:00:00" - return site - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_success( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test successful retrieval of site configuration.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - # Mock wraps.db for authentication - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - 
mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site.db for site query - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - response = api.get() - - # Assert - assert response["title"] == "Test Site" - assert response["icon"] == "icon-url" - assert response["description"] == "Site description" - mock_db.session.scalar.assert_called_once() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_not_found( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - ): - """Test that Forbidden is raised when site is not found.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query to return None - mock_db.session.scalar.return_value = None - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_tenant_archived( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test that Forbidden is raised when tenant is archived.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query - mock_db.session.scalar.return_value = mock_site - - # Set tenant status to archived AFTER authentication - mock_app_model.tenant.status = TenantStatus.ARCHIVE - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - 
@patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_queries_by_app_id( - self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model - ): - """Test that site is queried using the app model's id.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - mock_site = Mock(spec=Site) - mock_site.id = str(uuid.uuid4()) - mock_site.app_id = mock_app_model.id - mock_site.title = "Test Site" - mock_site.icon = "icon-url" - mock_site.icon_background = "#ffffff" - mock_site.description = "Site description" - mock_site.copyright = "Copyright 2024" - mock_site.privacy_policy = "Privacy policy text" - mock_site.custom_disclaimer = "Custom disclaimer" - mock_site.default_language = "en-US" - mock_site.prompt_public = True - mock_site.show_workflow_steps = True - mock_site.use_icon_as_answer_icon = False - mock_site.chat_color_theme = "light" - mock_site.chat_color_theme_inverted = False - mock_site.icon_type = "image" - mock_site.created_at = "2024-01-01T00:00:00" - mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - api.get() - - # Assert - # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index a2008e024b..6dfbdcf98e 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan from models.account import TenantStatus from models.model import ApiToken from tests.unit_tests.conftest import ( - setup_mock_dataset_tenant_query, - setup_mock_tenant_account_query, + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, ) @@ -141,14 +141,11 @@ class TestValidateAppToken: mock_account = Mock() mock_account.id = str(uuid.uuid4()) - mock_ta = Mock() - mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant via session.get() mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query (execute(select(...)).one_or_none()) - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + # Mock the tenant owner execute result (execute(select(...)).one_or_none()) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @validate_app_token def protected_view(app_model): @@ -471,7 +468,7 @@ class TestValidateDatasetToken: mock_account.current_tenant = mock_tenant # Mock the tenant account join query (execute(select(...)).one_or_none()) - 
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) # Mock the account lookup via session.get() mock_db.session.get.return_value = mock_account diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py index 274d78c9cf..b7f3244c6c 100644 --- a/api/tests/unit_tests/controllers/web/conftest.py +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -22,18 +22,16 @@ class FakeSession: def __init__(self, mapping: dict[str, Any] | None = None): self._mapping: dict[str, Any] = mapping or {} - self._model_name: str | None = None - def query(self, model: type) -> FakeSession: - self._model_name = model.__name__ - return self + def get(self, model: type, _ident: object) -> Any: + return self._mapping.get(model.__name__) - def where(self, *_args: object, **_kwargs: object) -> FakeSession: - return self - - def first(self) -> Any: - assert self._model_name is not None - return self._mapping.get(self._model_name) + def scalar(self, stmt: Any) -> Any: + try: + model = stmt.column_descriptions[0]["entity"] + except (AttributeError, IndexError, KeyError, TypeError): + return None + return self._mapping.get(model.__name__) class FakeDB: diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index cbfc8fa613..a6ca441801 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -22,6 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index 49039d03fe..4f8d848637 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -19,6 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index a1dbc80b20..5f2dc19aab 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -36,18 +36,6 @@ class _FakeSession: def __init__(self, mapping: dict[str, Any]): self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def 
where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) def get(self, model, ident): return self._mapping.get(model.__name__) diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index a01587d64a..13b953c04d 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -34,7 +34,6 @@ def _patch_wraps(): patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), patch("controllers.web.login.dify_config", web_dify), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index bc7aea0ef9..cde8820e00 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index 97206019b9..ea8cc8aa86 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index defc8b4b64..2f5873d865 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,6 +1,8 @@ import json import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -8,8 +10,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner - # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index a44a0650eb..17ab5babcb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,6 +3,11 @@ from typing import Any from unittest.mock import MagicMock import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import 
PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, @@ -11,11 +16,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent - # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index 5ee66da94a..186b4a501d 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -12,6 +10,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index e2f3c16335..d9fe7004ff 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) +from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 8bde9c1f97..11b53dd0f9 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,8 +1,7 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager - def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py index dd00c3defc..0a0ffe657c 100644 --- a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py 
+++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -77,6 +77,38 @@ class TestAdditionalFeatureManagers: SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( {"suggested_questions_after_answer": {"enabled": "yes"}} ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "prompt": 123}} + ) + with pytest.raises(ValueError, match="must be less than or equal to 1000 characters"): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "prompt": "a" * 1001}} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "model": "bad"}} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "model": {"provider": "openai"}}} + ) + + validated_config, _ = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + { + "suggested_questions_after_answer": { + "enabled": True, + "prompt": "custom prompt", + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"max_tokens": 1024}, + }, + } + } + ) + assert validated_config["suggested_questions_after_answer"]["prompt"] == "custom prompt" + assert validated_config["suggested_questions_after_answer"]["model"]["name"] == "gpt-4o-mini" assert ( SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index 000f83cd5a..f2bc3076da 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 1fb0dc6cf1..370f7abb8b 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { @@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - 
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool @@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock ConversationVariable.from_variable to return mock objects @@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index e9fdeefee4..6debeb4fdd 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,9 +1,10 @@ from collections.abc import Generator -from graphon.enums import WorkflowNodeExecutionStatus +import pytest from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, @@ -12,6 +13,8 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: @@ -29,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter: response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) assert "usage" not in response["metadata"] + def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch): + data = AdvancedChatPausedBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + workflow_run_id="run-1", + answer="partial", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + paused_nodes=["node-1"], + reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}], + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ) + original_model_dump = type(data).model_dump + + def _model_dump_with_future_field(self, *args, **kwargs): + payload = original_model_dump(self, *args, **kwargs) + payload["future_field"] = "future-value" + return payload + + monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field) + blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data) + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert 
response["data"]["future_field"] == "future-value" + def test_stream_simple_response_includes_node_events(self): node_start = NodeStartStreamResponse( task_id="t1", diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index a6d8598955..99a386cd45 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,8 +6,6 @@ from types import SimpleNamespace from unittest import mock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 82b2e51019..64bcfa9a18 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,8 +4,6 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -41,14 +39,20 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, AnnotationReply, AnnotationReplyAccount, + HumanInputRequiredResponse, MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import UserAction +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser @@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline: assert response.data.answer == "done" assert response.data.metadata == {"k": "v"} + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=7, + node_run_steps=3, + ) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + 
data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.workflow_run_id == "run-id" + assert response.data.status == "paused" + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "node_title": "Approval", + "form_content": "Need approval", + "inputs": [], + "actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], + "display_in_ui": True, + "form_token": "token-1", + "resolved_default_values": {}, + "expiration_time": 123, + } + ] + def test_handle_text_chunk_event_updates_state(self): pipeline = _make_pipeline() pipeline._message_cycle_manager = SimpleNamespace( diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 7dc4358150..80f7f94b1a 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 08250bc3b6..4567b35480 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 68bcffb0e8..8f3c41701b 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,7 +2,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -10,6 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f255d2c7df..b3ea1a464f 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 4a94a2b4f1..201923e0e4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12..dd6cd0e919 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,10 +1,9 @@ from collections.abc import Mapping, Sequence +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter - class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" @@ -12,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git 
a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index bc11bf4174..1bef6f69cd 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,13 +1,12 @@ from datetime import UTC, datetime from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index c9e146ff12..936ac37e55 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,11 +1,10 @@ from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 0fde7565d2..b3c0eb74fa 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,8 +10,6 @@ from typing import Any from unittest.mock import Mock import pytest -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -27,6 +25,8 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 619d66085a..aa2085177e 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock 
import pytest -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 96af9fbdee..f2e35f9900 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 6cdcab29ab..cfe797aa76 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 9a2dc38f74..c36edf48fc 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker): workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = workflow + session.get.return_value = workflow mocker.patch.object(module.db, "session", session) queue_manager = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 4fe82efcb3..9db83f5531 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,5 +1,4 @@ import pytest -from 
graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -14,6 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index c8ae288e6f..603062a51c 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker): app_generate_entity.single_iteration_run = None app_generate_entity.single_loop_run = None - query = MagicMock() - query.where.return_value.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.get.side_effect = [None, None] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline session = MagicMock() - session.query.return_value = query_pipeline + session.get.side_effect = [None, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py new file mode 100644 index 0000000000..560652f8cb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from collections.abc import Generator + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.task_entities import ( + AppStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from graphon.enums import WorkflowExecutionStatus + + +class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]): + blocking_full_calls: list[WorkflowAppBlockingResponse] = [] + blocking_simple_calls: list[WorkflowAppBlockingResponse] = [] + stream_full_calls: list[Generator[AppStreamResponse, None, None]] = [] + stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = [] + + @classmethod + def reset(cls) -> None: + cls.blocking_full_calls = [] + cls.blocking_simple_calls = [] + cls.stream_full_calls = [] + cls.stream_simple_calls = [] + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + 
cls.blocking_full_calls.append(blocking_response) + return {"kind": "blocking-full", "task_id": blocking_response.task_id} + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + cls.blocking_simple_calls.append(blocking_response) + return {"kind": "blocking-simple", "task_id": blocking_response.task_id} + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_full_calls.append(stream_response) + yield {"kind": "stream-full"} + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_simple_calls.append(stream_response) + yield {"kind": "stream-simple"} + + +def _build_blocking_response() -> WorkflowAppBlockingResponse: + return WorkflowAppBlockingResponse( + task_id="task-1", + workflow_run_id="run-1", + data=WorkflowAppBlockingResponse.Data( + id="run-1", + workflow_id="workflow-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + +def _build_stream_response() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + workflow_run_id="run-1", + stream_response=PingStreamResponse(task_id="task-1"), + ) + + +def test_convert_routes_blocking_response_by_invoke_from() -> None: + _DummyConverter.reset() + blocking_response = _build_blocking_response() + + full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API) + simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP) + + assert full_result == {"kind": "blocking-full", "task_id": "task-1"} + assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"} + assert _DummyConverter.blocking_full_calls == [blocking_response] + assert _DummyConverter.blocking_simple_calls == [blocking_response] + + +def test_convert_routes_stream_response_by_invoke_from() -> None: + _DummyConverter.reset() + + full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API)) + simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP)) + + assert full_result == [{"kind": "stream-full"}] + assert simple_result == [{"kind": "stream-simple"}] + assert len(_DummyConverter.stream_full_calls) == 1 + assert len(_DummyConverter.stream_simple_calls) == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 6167be3bbd..b0f8b423e1 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,9 +476,8 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from graphon.enums import BuiltinNodeTypes - from core.app.entities.app_invoke_entities import InvokeFrom + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = 
BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 1dee7fdab6..17de39ca99 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,15 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -23,6 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py index 25377e633e..90c9abf35c 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch from core.app.apps.message_generator import MessageGenerator +from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -23,7 +24,21 @@ class TestMessageGenerator: "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) ) as mock_stream, ): - events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + events = list( + MessageGenerator.retrieve_events( + AppMode.WORKFLOW, + "run-1", + idle_timeout=1, + ping_interval=2, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + ) assert events == [{"event": "ping"}] - mock_stream.assert_called_once() + mock_stream.assert_called_once_with( + topic="topic", + idle_timeout=1, + ping_interval=2, + on_subscribe=None, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a126bc85f7..6104b8d6ca 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,9 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import 
node_factory as node_factory_module +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -25,12 +30,6 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -56,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -90,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index a7714c56ce..58f0e47a4b 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults(): } +def test_normalize_terminal_events_empty_values(): + assert _normalize_terminal_events([]) == set() + + def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): topic = 
FakeTopic() times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] @@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): assert next(generator) == StreamEvent.PING.value # next receive yields None -> ping interval triggers assert next(generator) == StreamEvent.PING.value + + +def test_stream_topic_events_can_continue_past_pause(): + topic = FakeTopic() + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode()) + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode()) + + generator = stream_topic_events( + topic=topic, + idle_timeout=1.0, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + + assert next(generator) == StreamEvent.PING.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value + with pytest.raises(StopIteration): + next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index de5bca161c..58c7bfa4bc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,6 +4,23 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -24,23 +41,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables - class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index aa789d9ff3..10fb2271f4 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import 
QueueWorkflowPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 9e30faecf2..620a153204 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 8a717e1dcc..a3ab379b66 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,11 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -16,6 +11,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index b768e813bd..7dd7ffd727 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -11,6 +9,7 @@ from 
core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 29df903aa8..1f6e7e12ef 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,15 +2,14 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState - from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index d91bb85aee..0bcc1029b0 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -5,8 +5,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -38,15 +36,18 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, + WorkflowAppPausedBlockingResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser @@ -91,27 +92,50 @@ def _make_pipeline(): class TestWorkflowGenerateTaskPipeline: - def test_to_blocking_response_handles_pause(self): + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=5, + node_run_steps=2, + ) def _gen(): - yield 
WorkflowPauseStreamResponse( + yield HumanInputRequiredResponse( task_id="task", - workflow_run_id="run", - data=WorkflowPauseStreamResponse.Data( - workflow_run_id="run", - status=WorkflowExecutionStatus.PAUSED, - outputs={}, - created_at=1, - elapsed_time=0.1, - total_tokens=0, - total_steps=0, + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, ), ) response = pipeline._to_blocking_response(_gen()) + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.workflow_run_id == "run-id" assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.created_at == 0 + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "type": "human_input_required", + "form_id": "form-1", + "node_id": "node-1", + "node_title": "Human Input", + "form_content": "content", + "inputs": [], + "actions": [], + "display_in_ui": False, + "form_token": None, + "resolved_default_values": {}, + "expiration_time": 1, + } + ] def test_to_blocking_response_handles_finish(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 014a0cba72..7c79780641 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,11 +1,10 @@ -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index a78c1b428f..ba55e8f695 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from unittest.mock import Mock +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -8,10 +11,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment - -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035e64325b..539944d683 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ 
b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,16 @@ from time import time from unittest.mock import Mock import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -15,16 +25,6 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 95931f4f8b..12d49be0f1 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,6 +1,5 @@ -from graphon.graph_events import GraphRunPausedEvent - from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 7cf6eb4f31..1ac9a4d8c0 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,8 +1,7 @@ from unittest.mock import Mock, patch -from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand - from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index aa9285789b..d3bd15b6f3 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,11 +2,10 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool - from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool from 
models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index 58aa7d7478..c246f7b783 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 4aaa10a81a..1c1bf391d3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -28,6 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f22602a400..a20d89d807 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,9 +5,6 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -41,6 +38,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import 
AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 31b7313066..595d716666 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from graphon.file import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 29df7eea86..21c761c579 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.model_entities import ModelPropertyKey - from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index dc2d82ccd6..5c50cb78da 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index 7be9d6ac1e..701863b927 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest -from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, 
) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 8497261d45..30a068f4c5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,15 +1,15 @@ from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index a47d3db6f5..82552470a9 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,8 @@ from __future__ import annotations from types import SimpleNamespace -from graphon.enums import BuiltinNodeTypes - from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerExtras: diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index d8a68f6d00..cacb4dd4fa 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,6 +4,10 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause from graphon.enums import ( @@ -29,10 +33,6 @@ from graphon.graph_events import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables - class 
_RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 5ff9774b52..7b433ab57b 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,6 +301,7 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -308,8 +309,6 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -337,11 +336,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d338cadb77..deeac49bbc 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", 
extension=".pdf", diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index e4bd7d3bdf..d21b9e471b 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime: "last_edited_time": "2024-11-27T18:00:00.000Z", } mock_request.return_value = mock_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act extractor_page.update_last_edited_time(mock_document_model) @@ -863,9 +860,6 @@ class TestNotionExtractorIntegration: } mock_request.side_effect = [last_edited_response, block_response] - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act documents = extractor.extract() @@ -919,10 +913,6 @@ class TestNotionExtractorIntegration: } mock_post.return_value = database_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query - # Act documents = extractor.extract() diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index fbaf6d497d..0fca43cd0b 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ff9fd0d8f3..ef8f360dbf 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,12 +1,11 @@ -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType - from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 2acd278a31..aeca2e3afd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,9 +8,6 @@ drive provider mapping behavior. 
""" import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -19,6 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 8cf0409c4c..a28143026f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,17 +6,6 @@ from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -35,6 +24,17 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = 
configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + 
credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - 
credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index 8685d16283..a159d3ad4d 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -9,6 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 216cda83c5..63e887f904 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from core.extension.extensible import ExtensionModule @@ -12,10 +14,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - 
def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -28,10 +30,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -43,10 +45,10 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -55,10 +57,10 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e224..8cb0938575 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) 
- file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_creators.py b/api/tests/unit_tests/core/helper/test_creators.py new file mode 100644 index 0000000000..df67d3f513 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_creators.py @@ -0,0 +1,106 @@ +"""Tests for the Creators Platform helper module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from yarl import URL + + +@pytest.fixture(autouse=True) +def _patch_creators_url(monkeypatch): + """Patch the module-level creators_platform_api_url for all tests.""" + monkeypatch.setattr( + "core.helper.creators.creators_platform_api_url", + URL("https://creators.example.com"), + ) + + +class TestUploadDSL: + @patch("core.helper.creators.httpx.post") + def test_returns_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {"claim_code": "abc123"}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + result = upload_dsl(b"app: demo", "demo.yaml") + + assert result == "abc123" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert "anonymous-upload" in call_kwargs.args[0] + assert call_kwargs.kwargs["timeout"] == 30 + + @patch("core.helper.creators.httpx.post") + def test_raises_on_missing_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(ValueError, match="claim_code"): + upload_dsl(b"app: demo") + + @patch("core.helper.creators.httpx.post") + def test_raises_on_http_error(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server Error", + request=MagicMock(), + response=MagicMock(), + ) + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(httpx.HTTPStatusError): + upload_dsl(b"app: demo") + + +class TestGetRedirectUrl: + @patch("core.helper.creators.dify_config") + def test_without_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code" not in url + assert url.startswith("https://creators.example.com") + + @patch("core.helper.creators.dify_config") + def test_with_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz" + + with patch( + "services.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="oauth-code-123", + ) as mock_sign: + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + mock_sign.assert_called_once_with("client-xyz", "user-1") + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code=oauth-code-123" in url + + @patch("core.helper.creators.dify_config") + def 
test_strips_trailing_slash(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert url.startswith("https://creators.example.com?") + assert "creators.example.com/?" not in url diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index f3ef7fccd0..73e081a570 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -40,11 +40,11 @@ class TestObfuscatedToken: class TestEncryptToken: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_successful_encryption(self, mock_encrypt, mock_query): + def test_successful_encryption(self, mock_encrypt, mock_get): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -53,9 +53,9 @@ class TestEncryptToken: mock_encrypt.assert_called_with("test_token", "mock_public_key") @patch("extensions.ext_database.db.session.get") - def test_tenant_not_found(self, mock_query): + def test_tenant_not_found(self, mock_get): """Test error when tenant doesn't exist""" - mock_query.return_value = None + mock_get.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") - def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): + def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get): """Test that encryption and decryption are consistent""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -148,12 +148,12 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_cross_tenant_isolation(self, mock_encrypt, mock_query): + def test_cross_tenant_isolation(self, mock_encrypt, mock_get): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -183,10 +183,10 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_encryption_randomness(self, mock_encrypt, mock_query): + def test_encryption_randomness(self, mock_encrypt, mock_get): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -207,11 +207,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def 
test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): + def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -221,11 +221,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): + def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -244,11 +244,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): + def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py index 4a84099b74..a0dfa86d20 100644 --- a/api/tests/unit_tests/core/helper/test_moderation.py +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from typing import cast import pytest -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from pytest_mock import MockerFixture from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.moderation import check_moderation +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.provider import ProviderType diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e6597..d9fed9ae2a 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert 
wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index b45f6fd9a7..6ed9ddb476 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( @@ -30,6 +16,20 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 7cdfb31189..c4e610d5b0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 +2,13 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, 
RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @@ -96,6 +97,10 @@ class TestLLMGenerator: questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert len(questions) == 2 assert questions[0] == "Question 1?" + assert mock_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { + "max_tokens": 2560, + "temperature": 0.0, + } def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: @@ -113,6 +118,97 @@ class TestLLMGenerator: questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def test_generate_suggested_questions_after_answer_with_custom_model_and_prompt(self, mock_for_tenant): + custom_model_instance = MagicMock() + custom_response = MagicMock() + custom_response.message.get_text_content.return_value = '["Question 1?"]' + custom_model_instance.invoke_llm.return_value = custom_response + + mock_for_tenant.return_value.get_model_instance.return_value = custom_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + instruction_prompt="custom prompt", + model_config={ + "provider": "openai", + "name": "gpt-4o", + "completion_params": {"temperature": 0.2}, + }, + ) + + assert questions == ["Question 1?"] + mock_for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id="tenant_id", + model_type=ModelType.LLM, + provider="openai", + model="gpt-4o", + ) + + invoke_kwargs = custom_model_instance.invoke_llm.call_args.kwargs + assert invoke_kwargs["model_parameters"] == {"temperature": 0.2} + assert invoke_kwargs["stop"] == [] + assert "custom prompt" in invoke_kwargs["prompt_messages"][0].content + + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def test_generate_suggested_questions_after_answer_fallback_to_default_model(self, mock_for_tenant): + default_model_instance = MagicMock() + default_response = MagicMock() + default_response.message.get_text_content.return_value = '["Question 1?"]' + default_model_instance.invoke_llm.return_value = default_response + + mock_for_tenant.return_value.get_model_instance.side_effect = ValueError("invalid configured model") + mock_for_tenant.return_value.get_default_model_instance.return_value = default_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + model_config={ + "provider": "openai", + "name": "not-found-model", + "completion_params": {"temperature": 0.2}, + }, + ) + + assert questions == ["Question 1?"] + mock_for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant_id", + model_type=ModelType.LLM, + ) + assert default_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { + "max_tokens": 2560, + "temperature": 0.0, + } + assert default_model_instance.invoke_llm.call_args.kwargs["stop"] == [] + + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def 
test_generate_suggested_questions_after_answer_drops_non_positive_max_tokens(self, mock_for_tenant): + custom_model_instance = MagicMock() + custom_response = MagicMock() + custom_response.message.get_text_content.return_value = '["Question 1?"]' + custom_model_instance.invoke_llm.return_value = custom_response + mock_for_tenant.return_value.get_model_instance.return_value = custom_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + model_config={ + "provider": "openai", + "name": "gpt-4o", + "completion_params": { + "temperature": 0.2, + "max_tokens": 0, + "stop": ["END"], + }, + }, + ) + + assert questions == ["Question 1?"] + invoke_kwargs = custom_model_instance.invoke_llm.call_args.kwargs + assert invoke_kwargs["model_parameters"] == {"temperature": 0.2} + assert invoke_kwargs["stop"] == ["END"] + def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity): payload = RuleGeneratePayload( instruction="test instruction", model_config=model_config_entity, no_variable=True @@ -395,7 +491,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} @@ -421,7 +517,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() # Cause exception in node_type logic workflow.graph_dict = {"graph": {"nodes": []}} @@ -448,7 +544,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} @@ -536,7 +632,7 @@ class TestLLMGenerator: instance.invoke_llm.return_value = mock_response with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 81f8da9a62..bbbffa2e69 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -971,6 +971,23 @@ class TestHandlePostRequestNew: assert isinstance(item, SessionMessage) assert isinstance(item.message.root, JSONRPCError) assert item.message.root.id == 77 + assert item.message.root.error.message == "Session terminated by server" + + def 
test_404_on_initialization_includes_url_in_error(self): + t = _new_transport(url="http://example.com/mcp/server/abc123/mcp") + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.error.code == 32600 + assert "404 Not Found" in item.message.root.error.message + assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message def test_404_for_notification_no_error_sent(self): t = _new_transport() diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 9a815fb94d..57456085c3 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import jsonschema import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 9a5fb319d7..f459250b8e 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest + +from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -11,8 +13,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 6a672fdfd5..c4fd970562 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -12,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model 
import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index 3a97ad5c5d..4c668ee96b 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -10,6 +10,7 @@ This module tests all aspects of the content moderation system including: - Configuration validation """ +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -28,7 +29,7 @@ class TestKeywordsModeration: """Test suite for custom keyword-based content moderation.""" @pytest.fixture - def keywords_config(self) -> dict: + def keywords_config(self) -> dict[str, Any]: """ Fixture providing a standard keywords moderation configuration. @@ -48,7 +49,7 @@ class TestKeywordsModeration: } @pytest.fixture - def keywords_moderation(self, keywords_config: dict) -> KeywordsModeration: + def keywords_moderation(self, keywords_config: dict[str, Any]) -> KeywordsModeration: """ Fixture providing a KeywordsModeration instance. @@ -64,7 +65,7 @@ class TestKeywordsModeration: config=keywords_config, ) - def test_validate_config_success(self, keywords_config: dict): + def test_validate_config_success(self, keywords_config: dict[str, Any]): """Test successful validation of keywords moderation configuration.""" # Should not raise any exception KeywordsModeration.validate_config("test-tenant", keywords_config) @@ -274,7 +275,7 @@ class TestOpenAIModeration: """Test suite for OpenAI-based content moderation.""" @pytest.fixture - def openai_config(self) -> dict: + def openai_config(self) -> dict[str, Any]: """ Fixture providing OpenAI moderation configuration. @@ -293,7 +294,7 @@ class TestOpenAIModeration: } @pytest.fixture - def openai_moderation(self, openai_config: dict) -> OpenAIModeration: + def openai_moderation(self, openai_config: dict[str, Any]) -> OpenAIModeration: """ Fixture providing an OpenAIModeration instance. 
@@ -309,7 +310,7 @@ class TestOpenAIModeration: config=openai_config, ) - def test_validate_config_success(self, openai_config: dict): + def test_validate_config_success(self, openai_config: dict[str, Any]): """Test successful validation of OpenAI moderation configuration.""" # Should not raise any exception OpenAIModeration.validate_config("test-tenant", openai_config) diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42..69650c85cc 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", 
endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - 
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with 
pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class 
TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index 543b278715..c24d3ac012 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.message_entities import UserPromptMessage - from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary +from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index f8d0e127b1..88bf555594 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index a812b01c5b..f1c4c7e700 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ 
b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,12 +4,6 @@ from enum import StrEnum import pytest from flask import Response -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -31,6 +25,12 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index a3b1e5f6b0..704b82adc0 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,14 +17,6 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -45,6 +37,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 90730dff5a..00a4207786 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,12 +1,12 @@ from collections.abc import Generator import pytest -from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from graphon.file import File, FileTransferMethod, FileType class TestChunkMerger: @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, 
transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 2b280dd674..e536c0831f 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,13 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -11,13 +18,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, 
transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 4a54649b28..28966242d8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,19 +1,18 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index a4b3960b0a..5d865d934c 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,3 +1,5 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -7,9 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil - def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index e35ce2c48a..5308c8e7b3 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,16 +2,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform 
import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 3f188cfbb4..0dc74b33df 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,12 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -24,6 +18,12 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 006b4e7345..1f3247590c 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,13 +1,12 @@ from unittest.mock import MagicMock, patch -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError - from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index bbdd476914..1e91c2dd88 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -1,5 +1,6 @@ import json from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest @@ -28,15 +29,6 @@ class _Field: return ("in", self._name, tuple(values)) -class _FakeQuery: - def __init__(self): - self.where_calls: list[tuple] = [] - - def where(self, *conditions): - self.where_calls.append(conditions) - return self - - class _FakeExecuteResult: def __init__(self, segments: list[SimpleNamespace]): self._segments = segments @@ -57,7 +49,7 @@ class _FakeSelect: return self -def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict[str, Any] | None = None): return SimpleNamespace( data_source_type=data_source_type, keyword_table_dict=keyword_table_dict, diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py 
b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py
index 8b104597a8..b0ecad4d0c 100644
--- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py
+++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py
@@ -1,4 +1,5 @@
 from types import SimpleNamespace
+from typing import Any
 from unittest.mock import MagicMock, Mock, call, patch
 from uuid import uuid4
 
@@ -20,7 +21,7 @@ def create_mock_document(
     doc_id: str,
     score: float = 0.8,
     provider: str = "dify",
-    additional_metadata: dict | None = None,
+    additional_metadata: dict[str, Any] | None = None,
 ) -> Document:
     """
     Create a mock Document object for testing.
@@ -108,17 +109,6 @@ class _FakeExecuteResult:
         return _FakeExecuteScalarResult(self._data)
 
 
-class _FakeSummaryQuery:
-    def __init__(self, summaries: list) -> None:
-        self._summaries = summaries
-
-    def filter(self, *args, **kwargs):
-        return self
-
-    def all(self) -> list:
-        return self._summaries
-
-
 class _FakeScalarsResult:
     def __init__(self, data: list) -> None:
         self._data = data
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py
index dc21d378a2..f84ce2771f 100644
--- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py
+++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py
@@ -146,10 +146,7 @@ def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module,
 
 def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
     dataset = SimpleNamespace(id="dataset-1")
-    with (
-        patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
-        patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
-    ):
+    with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"):
         default_vector = vector_factory_module.Vector(dataset)
         custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
 
@@ -166,10 +163,57 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
         "original_chunk_id",
     ]
     assert custom_vector._attributes == ["doc_id"]
-    assert default_vector._embeddings == "embeddings"
+    # ``_embeddings`` is now a lazy proxy that defers materializing the real
+    # embedding model until an ``embed_*`` method is invoked, so cleanup paths
+    # that construct ``Vector(dataset)`` never trigger billing/feature-service
+    # calls. See ``_LazyEmbeddings``.
+    assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings)
     assert default_vector._vector_processor == "processor"
 
 
+def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch):
+    """``Vector(dataset)`` must not transitively call ``ModelManager`` during
+    construction. The real embedding model should only be materialized on the
+    first ``embed_*`` call (i.e. the create and search paths), so cleanup paths
+    (``delete_by_ids`` / ``delete``) remain resilient to billing-API failures.
+ """ + for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly")) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + + dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + proxy = vector_factory_module._LazyEmbeddings(dataset) + + # Construction alone does not trigger ModelManager / FeatureService / BillingService. + for_tenant_mock.assert_not_called() + + # Exercising an embed_* method materializes the real model exactly once. + inner_model = MagicMock() + inner_model.embed_documents.return_value = [[0.1, 0.2]] + cached_embedding_mock = MagicMock(return_value=inner_model) + real_for_tenant = MagicMock() + real_for_tenant.get_model_instance.return_value = "embedding-model-instance" + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant)) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock) + + result = proxy.embed_documents(["hello"]) + + assert result == [[0.1, 0.2]] + cached_embedding_mock.assert_called_once_with("embedding-model-instance") + inner_model.embed_documents.assert_called_once_with(["hello"]) + + # Subsequent calls reuse the materialized model (no re-resolution). + inner_model.embed_documents.reset_mock() + cached_embedding_mock.reset_mock() + proxy.embed_documents(["world"]) + cached_embedding_mock.assert_not_called() + inner_model.embed_documents.assert_called_once_with(["world"]) + + def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): calls = {"vector_type": None, "init_args": None} @@ -372,19 +416,11 @@ def test_vector_delegation_methods(vector_factory_module): def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): - class _Field: - def __eq__(self, value): - return value - - upload_query = MagicMock() - upload_query.where.return_value = upload_query - vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) vector._embeddings = MagicMock() vector._vector_processor = MagicMock() mock_session = SimpleNamespace(get=lambda _model, _id: None) - monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) assert vector.search_by_file("file-1") == [] diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index 3563186186..051a1455ae 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 
408cf14a51..4b8175b0b4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,6 +49,10 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -56,10 +60,6 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a..0220fb6d4a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,12 +1,14 @@ """Primarily used for testing merged cell scenarios""" +import gc import io import os import tempfile +import warnings from collections import UserDict from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from docx import Document @@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +def test_close_handles_async_close_mock(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + extractor.temp_file.close = AsyncMock() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + extractor.close() + gc.collect() + + extractor.temp_file.close.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index d4b987c832..4ba4d54fa0 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -1,15 +1,16 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import 
LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -71,7 +72,9 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="No rules found in process rule"): processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) - def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_validates_segmentation( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: rules_without_segmentation = SimpleNamespace(segmentation=None) with patch( @@ -84,7 +87,9 @@ class TestParagraphIndexProcessor: process_rule={"mode": "custom", "rules": {"enabled": True}}, ) - def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_builds_split_documents( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) splitter = Mock() splitter.split_documents.return_value = [ diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index b1b1835a52..bfae9001b7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch import pandas as pd @@ -77,7 +78,7 @@ class TestQAIndexProcessor: processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) def test_transform_preview_calls_formatter_once( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) split_node = Document(page_content=".question", metadata={}) @@ -119,7 +120,7 @@ class TestQAIndexProcessor: mock_format.assert_called_once() def test_transform_non_preview_uses_thread_batches( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: documents = [ Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 641c5d9ba0..b4bb343533 100644 --- 
a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,7 +53,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -64,6 +63,7 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk: mock_dependencies["redis"].get.return_value = None - # Mock database query for segment updates - mock_query = MagicMock() - mock_dependencies["db"].session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.update.return_value = None + # Mock database update for segment status + mock_dependencies["db"].session.execute.return_value = None # Create a proper context manager mock mock_context = MagicMock() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index c279b00d3b..8bc7dbf70d 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,7 +17,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -29,6 +28,7 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 1b17cbc368..fd607210f1 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,13 +1,12 @@ import threading from contextlib import contextmanager, nullcontext from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest from flask import Flask, current_app -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature from core.app.app_config.entities import ( DatasetEntity, @@ -34,6 +33,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import 
ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole @@ -45,7 +46,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -2021,7 +2022,7 @@ def create_mock_document_methods( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -2416,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) @@ -4091,7 +4091,7 @@ def _doc( dataset_id: str = "dataset-1", document_id: str = "document-1", doc_id: str = "node-1", - extra: dict | None = None, + extra: dict[str, Any] | None = None, ) -> Document: metadata = { "score": score, diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 48782515d0..aace419d15 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -55,7 +56,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. 
@@ -450,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
         mock_document.data_source_type = "upload_file"
         mock_document.doc_metadata = {}
 
-        mock_session.query.return_value.filter.return_value.all.return_value = [
-            mock_dataset_from_db
-        ]
-        mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
-            [mock_dataset_from_db, mock_document]
-        )
+        mock_datasets = MagicMock()
+        mock_datasets.all.return_value = [mock_dataset_from_db]
+        mock_documents = MagicMock()
+        mock_documents.all.return_value = [mock_document]
+        mock_session.scalars.side_effect = [mock_datasets, mock_documents]
 
         # Act
         result = dataset_retrieval.knowledge_retrieval(request)
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py
index 5a2ecb8220..43c521dcfd 100644
--- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py
+++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py
@@ -1,8 +1,7 @@
 from unittest.mock import Mock
 
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
+from graphon.model_runtime.entities.llm_entities import LLMUsage
 
 
 class TestFunctionCallMultiDatasetRouter:
diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py
index 539ac0f849..c56528cf55 100644
--- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py
+++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py
@@ -1,13 +1,12 @@
 from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.model_runtime.entities.message_entities import PromptMessageRole
 from graphon.model_runtime.entities.model_entities import ModelType
 
-from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
-from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
-
 
 class TestReactMultiDatasetRouter:
     def test_invoke_returns_none_when_no_tools(self) -> None:
diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
index e229d5fc1a..3d3322094e 100644
--- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py
@@ -9,10 +9,10 @@ from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import pytest
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowType
 
 from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowType
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
 from models.enums import WorkflowRunTriggeredFrom
diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
index 7dbf78d0f0..05b4f3a053 100644
--- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py
@@ -9,14 +9,14 @@ from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import pytest
+
+from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
+from core.repositories.factory import OrderConfig
 from graphon.entities.workflow_node_execution import (
     WorkflowNodeExecution,
     WorkflowNodeExecutionStatus,
 )
 from graphon.enums import BuiltinNodeTypes
-
-from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
-from core.repositories.factory import OrderConfig
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
 from models.workflow import WorkflowNodeExecutionTriggeredFrom
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
index 0fc82dda53..18ae9fafc8 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -7,11 +7,6 @@
 from datetime import datetime
 from types import SimpleNamespace
 
 import pytest
-from graphon.nodes.human_input.entities import (
-    FormDefinition,
-    UserAction,
-)
-from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 
 from core.repositories.human_input_repository import (
@@ -19,13 +14,18 @@ from core.repositories.human_input_repository import (
     HumanInputFormRecord,
     HumanInputFormSubmissionRepository,
     _WorkspaceMemberInfo,
 )
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     EmailDeliveryConfig,
     EmailDeliveryMethod,
     EmailRecipients,
     ExternalRecipient,
     MemberRecipient,
 )
+from graphon.nodes.human_input.entities import (
+    FormDefinition,
+    UserAction,
+)
+from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from libs.datetime_utils import naive_utc_now
 from models.human_input import (
     EmailExternalRecipientPayload,
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
index 8ff0e40587..4248782d93 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
@@ -9,8 +9,6 @@ from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
-from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
-from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 
 from core.repositories.human_input_repository import (
     FormCreateParams,
@@ -23,7 +21,7 @@ from core.repositories.human_input_repository import (
     _InvalidTimeoutStatusError,
     _WorkspaceMemberInfo,
 )
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     EmailDeliveryConfig,
     EmailDeliveryMethod,
     EmailRecipients,
@@ -31,6 +29,8 @@ from core.workflow.human_input_compat import (
     MemberRecipient,
     WebAppDeliveryMethod,
 )
+from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction
+from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from libs.datetime_utils import naive_utc_now
 from models.human_input import HumanInputFormRecipient, RecipientType
 
diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py
index e5c3e85487..a08c5729cb 100644
--- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py
@@ -3,12 +3,12 @@ from unittest.mock import MagicMock
 from uuid import uuid4
 
 import pytest
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus, WorkflowType
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
 from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowExecutionStatus, WorkflowType
 from models import Account, CreatorUserRole, EndUser, WorkflowRun
 from models.enums import WorkflowRunTriggeredFrom
 
diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py
index 5b4d26b780..6af7b02d4c 100644
--- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py
@@ -10,12 +10,6 @@ from unittest.mock import MagicMock, Mock
 
 import psycopg2.errors
 import pytest
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import (
-    BuiltinNodeTypes,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
 from sqlalchemy import Engine, create_engine
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import sessionmaker
@@ -29,6 +23,12 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import (
     _find_first,
     _replace_or_append_offload,
 )
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import (
+    BuiltinNodeTypes,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
 from models import Account, EndUser
 from models.enums import ExecutionOffLoadType
 from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py
index 84fe522388..abdbc72085 100644
--- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py
+++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py
@@ -4,17 +4,17 @@
 from unittest.mock import MagicMock, Mock
 
 import psycopg2.errors
 import pytest
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.enums import BuiltinNodeTypes
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import sessionmaker
 
 from core.repositories.sqlalchemy_workflow_node_execution_repository import (
     SQLAlchemyWorkflowNodeExecutionRepository,
 )
+from graphon.entities.workflow_node_execution import (
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.enums import BuiltinNodeTypes
 from libs.datetime_utils import naive_utc_now
 from models import Account, WorkflowNodeExecutionTriggeredFrom
 
diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py
index 27729e7f06..5af1376a0a 100644
--- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py
+++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py
@@ -11,17 +11,17 @@ from datetime import UTC, datetime
 from typing import Any
 from unittest.mock import MagicMock
 
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionStatus,
-)
-from graphon.enums import BuiltinNodeTypes
 from sqlalchemy import Engine
 
 from configs import dify_config
 from core.repositories.sqlalchemy_workflow_node_execution_repository import (
     SQLAlchemyWorkflowNodeExecutionRepository,
 )
+from graphon.entities.workflow_node_execution import (
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+)
+from graphon.enums import BuiltinNodeTypes
 from models import Account, WorkflowNodeExecutionTriggeredFrom
 from models.enums import ExecutionOffLoadType
 from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py
index ac65d0c02b..eab0176f41 100644
--- a/api/tests/unit_tests/core/test_file.py
+++ b/api/tests/unit_tests/core/test_file.py
@@ -1,15 +1,14 @@
 import json
 
 from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig
-
 from models.workflow import Workflow
 
 
 def test_file_to_dict():
     file = File(
-        id="file1",
+        file_id="file1",
         tenant_id="tenant1",
-        type=FileType.IMAGE,
+        file_type=FileType.IMAGE,
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url="https://example.com/image1.jpg",
         storage_key="storage_key",
diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py
index f5efb78b61..5a7e7e30a5 100644
--- a/api/tests/unit_tests/core/test_model_manager.py
+++ b/api/tests/unit_tests/core/test_model_manager.py
@@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 import redis
-from graphon.model_runtime.entities.model_entities import ModelType
 from pytest_mock import MockerFixture
 
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
-from core.model_manager import LBModelManager
+from core.model_manager import LBModelManager, ModelManager
 from extensions.ext_redis import redis_client
+from graphon.model_runtime.entities.model_entities import ModelType
 
 
 @pytest.fixture
@@ -40,6 +40,29 @@ def lb_model_manager():
     return lb_model_manager
 
 
+def test_model_manager_with_cache_enabled_reuses_stored_credentials():
+    """With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
+    provider_manager = MagicMock()
+    bundle = MagicMock()
+    bundle.configuration.provider.provider = "openai"
+    bundle.configuration.tenant_id = "tenant-1"
+    bundle.configuration.model_settings = []
+    bundle.model_type_instance.model_type = ModelType.LLM
+    get_creds = MagicMock(return_value={"api_key": "first"})
+    bundle.configuration.get_current_credentials = get_creds
+    provider_manager.get_provider_model_bundle.return_value = bundle
+
+    manager = ModelManager(provider_manager, enable_credentials_cache=True)
+    first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert first.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+    get_creds.return_value = {"api_key": "second"}
+    second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert second.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+
 def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
     # initialize redis client
     redis_client.initialize(redis.Redis())
diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py
index 331166fe63..b19a21d7f4 100644
--- a/api/tests/unit_tests/core/test_provider_configuration.py
+++ b/api/tests/unit_tests/core/test_provider_configuration.py
@@ -1,15 +1,6 @@
 from unittest.mock import Mock, patch
 
 import pytest
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.provider_entities import (
-    ConfigurateMethod,
-    CredentialFormSchema,
-    FormOption,
-    FormType,
-    ProviderEntity,
-)
 
 from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
 from core.entities.provider_entities import (
@@ -21,6 +12,15 @@ from core.entities.provider_entities import (
     RestrictModel,
     SystemConfiguration,
 )
+from graphon.model_runtime.entities.common_entities import I18nObject
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.entities.provider_entities import (
+    ConfigurateMethod,
+    CredentialFormSchema,
+    FormOption,
+    FormType,
+    ProviderEntity,
+)
 from models.provider import Provider, ProviderType
 
 
diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py
index ee26172459..a5a542c94f 100644
--- a/api/tests/unit_tests/core/test_provider_manager.py
+++ b/api/tests/unit_tests/core/test_provider_manager.py
@@ -2,12 +2,12 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock, Mock, PropertyMock, patch
 
 import pytest
-from graphon.model_runtime.entities.common_entities import I18nObject
-from graphon.model_runtime.entities.model_entities import ModelType
 from pytest_mock import MockerFixture
 
 from core.entities.provider_entities import ModelSettings
 from core.provider_manager import ProviderManager
+from graphon.model_runtime.entities.common_entities import I18nObject
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel
 from models.provider_ids import ModelProviderID
 
@@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration(
     provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
 
 
+def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity):
+    manager = _build_provider_manager(mocker)
+    provider_configuration = Mock()
+    provider_factory = Mock()
+    provider_factory.get_providers.return_value = [mock_provider_entity]
+    custom_configuration = SimpleNamespace(provider=None, models=[])
+    system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+
+    with (
+        patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+        patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+        patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+        patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+        patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+        patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+        patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+        patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+        patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+        patch.object(manager, "_to_model_settings", return_value=[]),
+        patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
+        patch(
+            "core.provider_manager.ProviderConfiguration",
+            return_value=provider_configuration,
+        ) as mock_provider_configuration,
+    ):
+        first = manager.get_configurations("tenant-id")
+        second = manager.get_configurations("tenant-id")
+
+    assert first is second
+    mock_get_all_providers.assert_called_once_with("tenant-id")
+    mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
+    mock_provider_configuration.assert_called_once()
+    provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
+def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity):
+    manager = _build_provider_manager(mocker)
+    provider_factory = Mock()
+    provider_factory.get_providers.return_value = [mock_provider_entity]
+    custom_configuration = SimpleNamespace(provider=None, models=[])
+    system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+    provider_configuration_first = Mock()
+    provider_configuration_second = Mock()
+
+    with (
+        patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+        patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+        patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+        patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+        patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+        patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+        patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+        patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+        patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+        patch.object(manager, "_to_model_settings", return_value=[]),
+        patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
+        patch(
+            "core.provider_manager.ProviderConfiguration",
+            side_effect=[provider_configuration_first, provider_configuration_second],
+        ) as mock_provider_configuration,
+    ):
+        first = manager.get_configurations("tenant-id")
+        manager.clear_configurations_cache("tenant-id")
+        second = manager.get_configurations("tenant-id")
+
+    assert first is not second
+    assert mock_get_all_providers.call_count == 2
+    assert mock_provider_configuration.call_count == 2
+    provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+    provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
 def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
     manager = _build_provider_manager(mocker)
     provider_configuration = Mock()
diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py
index 5d744f88c9..1ff81f6120 100644
--- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py
+++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py
@@ -6,13 +6,13 @@ from typing import Any
 from unittest.mock import patch
 
 import pytest
-from graphon.model_runtime.entities.message_entities import UserPromptMessage
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
+from graphon.model_runtime.entities.message_entities import UserPromptMessage
 
 
 class _BuiltinDummyTool(BuiltinTool):
diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py
index ee0ce51eec..c7829fc0d7 100644
--- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py
+++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py
@@ -6,8 +6,6 @@ from datetime import date
 from types import SimpleNamespace
 
 import pytest
-from graphon.file import FileType
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -29,6 +27,8 @@ from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
 from core.tools.errors import ToolInvokeError
+from graphon.file import FileType
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 
 
 def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool:
diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py
index 79b8eaaa87..f35546b025 100644
--- a/api/tests/unit_tests/core/tools/test_custom_tool.py
+++ b/api/tests/unit_tests/core/tools/test_custom_tool.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from types import SimpleNamespace
+from typing import Any
 
 import httpx
 import pytest
@@ -14,7 +15,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvo
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
 
 
-def _build_tool(*, openapi: dict | None = None) -> ApiTool:
+def _build_tool(*, openapi: dict[str, Any] | None = None) -> ApiTool:
     entity = ToolEntity(
         identity=ToolIdentity(
             author="author",
diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py
index 2889cb9db1..ccffdf16d1 100644
--- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py
+++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py
@@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch
 
 import httpx
 import pytest
-from graphon.file import FileTransferMethod
 
 from core.tools.tool_file_manager import ToolFileManager
+from graphon.file import FileTransferMethod
 
 
 def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py
index 8c0e7e9419..e13f430f9b 100644
--- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py
+++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from types import SimpleNamespace
 from typing import Any
-from unittest.mock import PropertyMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
 
 import pytest
 
@@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 
 
+# Create a mock class for testing abstract/base classes
 class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
     def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
         return None
 
 
+# Factory function to create a "lightweight" controller for testing
 def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController:
     controller = object.__new__(ApiToolProviderController)
     controller.provider_id = provider_id
@@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr
     return controller
 
 
+# Test pure logic: filtering and deduplication
 def test_tool_label_manager_filter_tool_labels():
     filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
     assert set(filtered) == {"search", "news"}
@@ -36,22 +39,68 @@
 def test_tool_label_manager_update_tool_labels_db():
+    """
+    Test the database update logic for tool labels.
+    Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session.
+    """
+    # 1. Setup expected data from the controller
     controller = _api_controller("api-1")
-    with patch("core.tools.tool_label_manager.db") as mock_db:
+    expected_id = controller.provider_id
+    expected_type = controller.provider_type
+
+    # 2. Patching External Dependencies
+    # - We patch 'db' to prevent Flask from trying to access a real database.
+    # - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions.
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        # 3. Constructing the "Mocking Chain"
+        # In the business logic, we use: with sessionmaker(db.engine).begin() as _session:
+        # We need to link our 'mock_session' to the end of this complex context manager chain:
+        # Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value)
+        # Step B: .begin() -> returns a context manager (begin.return_value)
+        # Step C: with ... as _session: -> calls __enter__(), and _session gets the __enter__.return_value
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+
+        # 4. Trigger the logic under test
+        # Input: ["search", "search", "invalid"]
+        # Logic:
+        # - "invalid" should be filtered out (not in default_tool_label_name_list).
+        # - The duplicate "search" should be merged (unique labels).
         ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
 
-        mock_db.session.execute.assert_called_once()
-        # only one valid unique label should be inserted.
-        assert mock_db.session.add.call_count == 1
-        mock_db.session.commit.assert_called_once()
+        # 5. Behavior Assertion: DELETE operation
+        # Verify that the manager first attempts to clear existing labels for this specific tool.
+        # This ensures the update is idempotent.
+        mock_session.execute.assert_called_once()
+
+        # 6. Behavior Assertion: INSERT operation
+        # Verify that only ONE valid label ("search") was added after filtering and deduplication.
+        # If call_count == 1, it proves filter_tool_labels() worked as expected.
+        assert mock_session.add.call_count == 1
+
+        # 7. State Assertion: Data Integrity & Isolation
+        # Inspect the actual object passed to session.add() to ensure it has correct properties.
+        # This confirms that the data isolation (tool_id + tool_type) we refactored is active.
+        call_args = mock_session.add.call_args
+        added_label = call_args[0][0]  # Retrieve the ToolLabelBinding instance
+
+        assert added_label.label_name == "search", "The label name should be 'search' after filtering."
+        assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding."
+        assert added_label.tool_type == expected_type, "Isolation failed: tool_type must be verified during update."
 
 
+# Test error handling
 def test_tool_label_manager_update_tool_labels_unsupported():
     with pytest.raises(ValueError, match="Unsupported tool type"):
         ToolLabelManager.update_tool_labels(object(), ["search"])  # type: ignore[arg-type]
 
 
+# Test retrieval logic
 def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
+    # Mocking a property (@property) using PropertyMock
     with patch.object(
         _ConcreteBuiltinToolProviderController,
         "tool_labels",
@@ -62,29 +111,67 @@ def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
         assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
 
     api = _api_controller("api-1")
-    with patch("core.tools.tool_label_manager.db") as mock_db:
-        mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
-        labels = ToolLabelManager.get_tool_labels(api)
-        assert labels == ["search", "news"]
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+        # Inject mock data into the query result: session.scalars(stmt).all()
+        mock_session.scalars.return_value.all.return_value = ["search", "news"]
+
+        labels = ToolLabelManager.get_tool_labels(api)
+        assert labels == ["search", "news"]
+
+
+def test_tool_label_manager_get_tool_labels_unsupported():
+    """
+    Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types.
+    This protects the internal API contract against accidental regressions during refactoring.
+    """
+    # Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers.
     with pytest.raises(ValueError, match="Unsupported tool type"):
         ToolLabelManager.get_tool_labels(object())  # type: ignore[arg-type]
 
 
+# Test batch processing and mapping
 def test_tool_label_manager_get_tools_labels_batch():
     assert ToolLabelManager.get_tools_labels([]) == {}
 
     api = _api_controller("api-1")
     wf = _workflow_controller("wf-1")
+
+    # SimpleNamespace is a quick way to simulate SQLAlchemy row objects
     records = [
         SimpleNamespace(tool_id="api-1", label_name="search"),
         SimpleNamespace(tool_id="api-1", label_name="news"),
         SimpleNamespace(tool_id="wf-1", label_name="utilities"),
     ]
-    with patch("core.tools.tool_label_manager.db") as mock_db:
-        mock_db.session.scalars.return_value.all.return_value = records
+
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+
+        # Simulating the batch query result
+        mock_session.scalars.return_value.all.return_value = records
+
         labels = ToolLabelManager.get_tools_labels([api, wf])
+
+        # Verify the final dictionary mapping
         assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
 
+
+def test_tool_label_manager_get_tools_labels_unsupported():
+    """
+    Negative Test: Ensure get_tools_labels raises ValueError if the list contains
+    unsupported controller types, even alongside valid ones.
+    """
+    api = _api_controller("api-1")
+
+    # Passing a list with one valid controller and one invalid object()
     with pytest.raises(ValueError, match="Unsupported tool type"):
         ToolLabelManager.get_tools_labels([api, object()])  # type: ignore[list-item]
diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py
index 6454a5bcd1..5f34135af4 100644
--- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py
+++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import pytest
 
 import core.tools.utils.message_transformer as mt
@@ -13,7 +15,7 @@ class _FakeToolFile:
 class _FakeToolFileManager:
     """Fake ToolFileManager to capture the mimetype passed in."""
 
-    last_call: dict | None = None
+    last_call: dict[str, Any] | None = None
 
     def __init__(self, *args, **kwargs):
         pass
diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
index 52f262e1cf..44785f939c 100644
--- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
+++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py
@@ -10,9 +10,12 @@ from __future__ import annotations
 
 from decimal import Decimal
 from types import SimpleNamespace
+from typing import Any
 from unittest.mock import Mock, patch
 
 import pytest
+
+from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
 from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 from graphon.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -22,10 +25,8 @@ from graphon.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )
 
-from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils
-
 
-def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace:
+def _mock_model_instance(*, schema: dict[str, Any] | None = None) -> SimpleNamespace:
     model_type_instance = Mock()
     model_type_instance.get_model_schema.return_value = (
         SimpleNamespace(model_properties=schema or {}) if schema is not None else None
diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py
index 40f91b12a0..032b1377a4 100644
--- a/api/tests/unit_tests/core/tools/utils/test_parser.py
+++ b/api/tests/unit_tests/core/tools/utils/test_parser.py
@@ -1,4 +1,5 @@
 from json.decoder import JSONDecodeError
+from typing import Any
 from unittest.mock import Mock, patch
 
 import pytest
@@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app):
         },
     }
 
-    extra_info: dict = {}
-    warning: dict = {}
+    extra_info: dict[str, Any] = {}
+    warning: dict[str, Any] = {}
 
     with app.test_request_context(headers={"X-Request-Env": "prod"}):
         bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches():
         }
     )
 
-    warning: dict = {"seed": True}
+    warning: dict[str, Any] = {"seed": True}
     converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
         {
             "servers": [{"url": "https://x"}],
diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py
index 5691f33e65..6bb86ebe78 100644
--- a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py
+++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py
@@ -2,50 +2,50 @@ from __future__ import annotations
 
 import pytest
 
-from core.tools.utils import system_oauth_encryption as oauth_encryption
-from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter
+from core.tools.utils import system_encryption as encryption
+from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter
 
 
-def test_system_oauth_encrypter_roundtrip():
-    encrypter = SystemOAuthEncrypter(secret_key="test-secret")
+def test_system_encrypter_roundtrip():
+    encrypter = SystemEncrypter(secret_key="test-secret")
     payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}
 
-    encrypted = encrypter.encrypt_oauth_params(payload)
-    decrypted = encrypter.decrypt_oauth_params(encrypted)
+    encrypted = encrypter.encrypt_params(payload)
+    decrypted = encrypter.decrypt_params(encrypted)
 
     assert encrypted
     assert dict(decrypted) == payload
 
 
-def test_system_oauth_encrypter_decrypt_validates_input():
-    encrypter = SystemOAuthEncrypter(secret_key="test-secret")
+def test_system_encrypter_decrypt_validates_input():
+    encrypter = SystemEncrypter(secret_key="test-secret")
 
     with pytest.raises(ValueError, match="must be a string"):
-        encrypter.decrypt_oauth_params(123)  # type: ignore[arg-type]
+        encrypter.decrypt_params(123)  # type: ignore[arg-type]
 
     with pytest.raises(ValueError, match="cannot be empty"):
-        encrypter.decrypt_oauth_params("")
+        encrypter.decrypt_params("")
 
 
-def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
-    encrypter = SystemOAuthEncrypter(secret_key="test-secret")
+def test_system_encrypter_raises_error_for_invalid_ciphertext():
+    encrypter = SystemEncrypter(secret_key="test-secret")
 
-    with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
-        encrypter.decrypt_oauth_params("not-base64")
+    with pytest.raises(EncryptionError, match="Decryption failed"):
+        encrypter.decrypt_params("not-base64")
 
 
-def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
-    monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
-    monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")
+def test_system_helpers_use_global_cached_instance(monkeypatch):
+    monkeypatch.setattr(encryption, "_encrypter", None)
+    monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret")
 
-    first = oauth_encryption.get_system_oauth_encrypter()
-    second = oauth_encryption.get_system_oauth_encrypter()
+    first = encryption.get_system_encrypter()
+    second = encryption.get_system_encrypter()
 
     assert first is second
 
-    encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
-    assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}
+    encrypted = encryption.encrypt_system_params({"k": "v"})
+    assert encryption.decrypt_system_params(encrypted) == {"k": "v"}
 
 
-def test_create_system_oauth_encrypter_factory():
-    encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
-    assert isinstance(encrypter, SystemOAuthEncrypter)
+def test_create_system_encrypter_factory():
+    encrypter = encryption.create_system_encrypter(secret_key="factory-secret")
+    assert isinstance(encrypter, SystemEncrypter)
diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
index 0e3a7e623a..43f3fbd5c9 100644
--- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
+++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py
@@ -1,9 +1,9 @@
 import pytest
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
 from core.tools.errors import WorkflowToolHumanInputNotSupportedError
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 
 def test_ensure_no_human_input_nodes_passes_for_non_human_input():
diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
index 4767480a5a..5a585c609a 100644
--- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
+++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py
@@ -4,7 +4,6 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
@@ -14,6 +13,7 @@ from core.tools.entities.tool_entities import (
     ToolProviderType,
 )
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 
 
 def _controller() -> WorkflowToolProviderController:
diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
index c20edd7400..72a73dd936 100644
--- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
+++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
@@ -11,7 +11,6 @@ from typing import Any
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.errors import ToolInvokeError
 from core.tools.workflow_as_tool.tool import WorkflowTool
+from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType
 
 
 class StubScalars:
diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
index 78622b78b6..fb7dc52838 100644
--- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
+++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py
@@ -8,10 +8,10 @@ and select_trigger_debug_events orchestrator.
 from __future__ import annotations
 
 from datetime import datetime
+from typing import Any
 from unittest.mock import MagicMock, patch
 
 import pytest
-from graphon.enums import BuiltinNodeTypes, NodeType
 
 from core.plugin.entities.request import TriggerInvokeEventResponse
 from core.trigger.constants import (
@@ -27,10 +27,11 @@ from core.trigger.debug.event_selectors import (
     select_trigger_debug_events,
 )
 from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
+from graphon.enums import BuiltinNodeTypes, NodeType
 from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
 
 
-def _make_poller_args(node_config: dict | None = None) -> dict:
+def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]:
     return {
         "tenant_id": "t1",
         "user_id": "u1",
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 7406b88270..9e07ea1b6d 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,23 +1,30 @@
 import dataclasses
+from typing import Annotated
 
 import orjson
 import pytest
+from pydantic import BaseModel, Discriminator, Tag
+
+from core.helper import encrypter
+from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
+from core.workflow.variable_pool_initializer import add_variables_to_pool
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.runtime import VariablePool
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import (
     ArrayAnySegment,
+    ArrayBooleanSegment,
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
     ArrayStringSegment,
+    BooleanSegment,
     FileSegment,
     FloatSegment,
     IntegerSegment,
     NoneSegment,
     ObjectSegment,
     Segment,
-    SegmentUnion,
     StringSegment,
     get_segment_discriminator,
 )
@@ -42,11 +49,26 @@ from graphon.variables.variables import (
     StringVariable,
     Variable,
 )
-from pydantic import BaseModel
+from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
 
-from core.helper import encrypter
-from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
-from core.workflow.variable_pool_initializer import add_variables_to_pool
+type SegmentUnion = Annotated[
+    (
+        Annotated[NoneSegment, Tag(SegmentType.NONE)]
+        | Annotated[StringSegment, Tag(SegmentType.STRING)]
+        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+        | Annotated[FileSegment, Tag(SegmentType.FILE)]
+        | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
+        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+        | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+    ),
+    Discriminator(get_segment_discriminator),
+]
 
 
 def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
 ) -> File:
     """Factory function to create File objects for testing"""
     return File(
-        type=file_type,
+        file_type=file_type,
         transfer_method=transfer_method,
         filename=filename,
         extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
         assert restored == model
 
     def test_all_segments_serialization(self):
-        """Test serialization/deserialization of all segment types"""
+        """Test file-aware segment serialization through Dify's model boundary."""
         # Create one instance of each segment type
         test_file = create_test_file()
 
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
         # Test serialization and deserialization
         model = _Segments(segments=all_segments)
         json_str = model.model_dump_json()
-        loaded = _Segments.model_validate_json(json_str)
+        loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
 
         # Verify all segments are preserved
         assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
             assert loaded_segment.value == original.value
 
     def test_all_variables_serialization(self):
-        """Test serialization/deserialization of all variable types"""
+        """Test file-aware variable serialization through Dify's model boundary."""
        # Create one instance of each variable type
         test_file = create_test_file()
 
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
         # Test serialization and deserialization
         model = _Variables(variables=all_variables)
         json_str = model.model_dump_json()
-        loaded = _Variables.model_validate_json(json_str)
+        loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
 
         # Verify all variables are preserved
         assert len(loaded.variables) == len(all_variables)
diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py
index 37ecd2890b..d4e862220a 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type.py
@@ -1,4 +1,5 @@
 import pytest
+
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import StringSegment
 from graphon.variables.types import ArrayValidation, SegmentType
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 09254e17a3..317fe99d37 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -9,6 +9,7 @@ from dataclasses import dataclass
 from typing import Any
 
 import pytest
+
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.variables.segment_group import SegmentGroup
 from graphon.variables.segments import (
@@ -34,7 +35,7 @@ def create_test_file(
     """Factory function to create File objects for testing."""
     return File(
         tenant_id="test-tenant",
-        type=file_type,
+        file_type=file_type,
         transfer_method=transfer_method,
         filename=filename,
         extension=extension,
diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py
index 75b01bf42e..dae5e1ce98 100644
--- a/api/tests/unit_tests/core/variables/test_variables.py
+++ b/api/tests/unit_tests/core/variables/test_variables.py
@@ -1,4 +1,6 @@
 import pytest
+from pydantic import ValidationError
+
 from graphon.variables import (
     ArrayFileVariable,
     ArrayVariable,
@@ -10,7 +12,6 @@ from graphon.variables import (
     StringVariable,
 )
 from graphon.variables.variables import VariableBase
-from pydantic import ValidationError
 
 
 def test_frozen_variables():
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
index 41627f5e0b..025d79b25d 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py
@@ -5,12 +5,13 @@ Shared fixtures for ObservabilityLayer tests.
 from unittest.mock import MagicMock, patch
 
 import pytest
-from graphon.enums import BuiltinNodeTypes
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 from opentelemetry.trace import set_tracer_provider
 
+from graphon.enums import BuiltinNodeTypes
+
 
 @pytest.fixture
 def memory_span_exporter():
@@ -61,9 +62,8 @@ def mock_llm_node():
 @pytest.fixture
 def mock_tool_node():
     """Create a mock Tool Node with tool-specific attributes."""
-    from graphon.nodes.tool.entities import ToolNodeData
-
     from core.tools.entities.tool_entities import ToolProviderType
+    from graphon.nodes.tool.entities import ToolNodeData
 
     node = MagicMock()
     node.id = "test-tool-node-id"
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
index 99d131737e..5d6667257f 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
@@ -3,17 +3,16 @@ from datetime import datetime
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
+from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
+from core.app.workflow.layers.llm_quota import LLMQuotaLayer
+from core.errors.error import QuotaExceededError
+from core.model_manager import ModelInstance
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.graph_engine.entities.commands import CommandType
 from graphon.graph_events import NodeRunSucceededEvent
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.node_events import NodeRunResult
 
-from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
-from core.app.workflow.layers.llm_quota import LLMQuotaLayer
-from core.errors.error import QuotaExceededError
-from core.model_manager import ModelInstance
-
 
 def _build_dify_context() -> DifyRunContext:
     return DifyRunContext(
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
index 9cf72763ee..919f15efd0 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py
@@ -13,10 +13,10 @@ Test coverage:
 from unittest.mock import patch
 
 import pytest
-from graphon.enums import BuiltinNodeTypes
 from opentelemetry.trace import StatusCode
 
 from core.app.workflow.layers.observability import ObservabilityLayer
+from graphon.enums import BuiltinNodeTypes
 
 
 class TestObservabilityLayerInitialization:
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 88989db856..9f3e3b00b9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -1,18 +1,18 @@
-"""
-Mock node factory for testing workflows with third-party service dependencies.
+"""Mock node factory for third-party-service workflow tests.
 
-This module provides a MockNodeFactory that automatically detects and mocks nodes
-requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
+The factory follows the same config adaptation path as production
+`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
+implementations before instantiation.
 """
 
 from typing import TYPE_CHECKING, Any
 
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
+from core.workflow.node_factory import DifyNodeFactory
 from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
 from graphon.enums import BuiltinNodeTypes, NodeType
 from graphon.nodes.base.node import Node
 
-from core.workflow.node_factory import DifyNodeFactory
-
 from .test_mock_nodes import (
     MockAgentNode,
     MockCodeNode,
@@ -83,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
         :param node_config: Node configuration dictionary
         :return: Node instance (real or mocked)
         """
-        typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+        typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
+        node_id = typed_node_config["id"]
         node_data = typed_node_config["data"]
         node_type = node_data.type
 
         # Check if this node type should be mocked
         if node_type in self._mock_node_types:
-            node_id = typed_node_config["id"]
-
             # Create mock node instance
             mock_class = self._mock_node_types[node_type]
+            resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
             if node_type == BuiltinNodeTypes.CODE:
                 mock_instance = mock_class(
-                    id=node_id,
-                    config=typed_node_config,
+                    node_id=node_id,
+                    config=resolved_node_data,
                     graph_init_params=self.graph_init_params,
                     graph_runtime_state=self.graph_runtime_state,
                     mock_config=self.mock_config,
@@ -105,8 +105,8 @@
                 )
             elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
                 mock_instance = mock_class(
-                    id=node_id,
-                    config=typed_node_config,
+                    node_id=node_id,
+                    config=resolved_node_data,
                     graph_init_params=self.graph_init_params,
                     graph_runtime_state=self.graph_runtime_state,
                     mock_config=self.mock_config,
@@ -121,8 +121,8 @@
                 BuiltinNodeTypes.PARAMETER_EXTRACTOR,
             }:
                 mock_instance = mock_class(
-                    id=node_id,
-                    config=typed_node_config,
+                    node_id=node_id,
+                    config=resolved_node_data,
                    graph_init_params=self.graph_init_params,
                     graph_runtime_state=self.graph_runtime_state,
                     mock_config=self.mock_config,
@@ -131,8 +131,8 @@
                 )
             else:
                 mock_instance = mock_class(
-                    id=node_id,
-                    config=typed_node_config,
+                    node_id=node_id,
+                    config=resolved_node_data,
                     graph_init_params=self.graph_init_params,
                     graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config, @@ -141,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8b7fbd1b30..f9819c47ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,6 +10,10 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -27,11 +31,6 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode - if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState @@ -56,13 +55,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -97,7 +97,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 8311a1e847..75bc6d05f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.graph import Graph from graphon.graph_engine import GraphEngine, 
GraphEngineConfig @@ -23,14 +30,6 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -140,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -155,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -165,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index b11f957677..7d23b63049 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,11 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -39,12 +44,6 @@ from graphon.variables import ( StringVariable, ) -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from 
-from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
-
 from .test_mock_config import MockConfig
 from .test_mock_factory import MockNodeFactory

diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py
index cbc920705c..1f4509af9a 100644
--- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py
@@ -1,9 +1,8 @@
 from unittest.mock import patch

-from graphon.enums import BuiltinNodeTypes
-
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
+from graphon.enums import BuiltinNodeTypes


 def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None:
diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py
index 59dd763b59..c86de7f6e6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py
+++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py
@@ -1,9 +1,8 @@
 from types import SimpleNamespace
 from unittest.mock import Mock, patch

-from graphon.model_runtime.entities.model_entities import ModelType
-
 from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
+from graphon.model_runtime.entities.model_entities import ModelType


 def test_fetch_model_reuses_single_model_assembly():
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 7195471eb6..76b4cd1ef4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -2,15 +2,15 @@ import time
 import uuid
 from unittest.mock import MagicMock

-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.graph import Graph
-from graphon.nodes.answer.answer_node import AnswerNode
-from graphon.runtime import GraphRuntimeState, VariablePool
-
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_database import db
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.graph import Graph
+from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.answer.entities import AnswerNodeData
+from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -67,20 +67,15 @@ def test_execute_answer():

     graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")

-    node_config = {
-        "id": "answer",
-        "data": {
-            "title": "123",
-            "type": "answer",
-            "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
-        },
-    }
-
     node = AnswerNode(
-        id=str(uuid.uuid4()),
+        node_id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config=node_config,
+        config=AnswerNodeData(
+            title="123",
+            type="answer",
+            answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+        ),
     )

     # Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
index 343bcd3919..ec4cef1955 100644
--- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
@@ -1,10 +1,10 @@
 import pytest
+
+from core.workflow.node_factory import get_node_type_classes_mapping
 from graphon.entities.base_node_data import BaseNodeData
 from graphon.enums import BuiltinNodeTypes, NodeType
 from graphon.nodes.base.node import Node

-from core.workflow.node_factory import get_node_type_classes_mapping
-
 # Ensures that all production node classes are imported and registered.
 _ = get_node_type_classes_mapping()
diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py
index b9371a34f4..ef0df55995 100644
--- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py
+++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py
@@ -1,6 +1,7 @@
 import types
 from collections.abc import Mapping

+from core.workflow.node_factory import get_node_type_classes_mapping
 from graphon.entities.base_node_data import BaseNodeData
 from graphon.enums import BuiltinNodeTypes, NodeType
 from graphon.nodes.base.node import Node
@@ -13,8 +14,6 @@ from graphon.nodes.variable_assigner.v2.node import (
     VariableAssignerNode as VariableAssignerV2,
 )

-from core.workflow.node_factory import get_node_type_classes_mapping
-

 def test_variable_assigner_latest_prefers_highest_numeric_version():
     # Act
diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
index d155124c50..ce0c9b79c6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py
@@ -1,3 +1,4 @@
+from configs import dify_config
 from graphon.nodes.code.code_node import CodeNode
 from graphon.nodes.code.entities import CodeLanguage, CodeNodeData
 from graphon.nodes.code.exc import (
@@ -8,8 +9,6 @@ from graphon.nodes.code.exc import (
 from graphon.nodes.code.limits import CodeNodeLimits
 from graphon.variables.types import SegmentType

-from configs import dify_config
-
 CodeNode._limits = CodeNodeLimits(
     max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
     max_number=dify_config.CODE_MAX_NUMBER,
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
index fb03ae9998..d7ef781732 100644
--- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -1,8 +1,8 @@
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
-
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
 from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent


 class _VarSeg:
@@ -78,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
     mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)

     node = DatasourceNode(
-        id="n",
-        config={
-            "id": "n",
-            "data": {
-                "type": "datasource",
-                "version": "1",
-                "title": "Datasource",
-                "provider_type": "plugin",
-                "provider_name": "p",
-                "plugin_id": "plug",
-                "datasource_name": "ds",
-            },
-        },
+        node_id="n",
+        config=DatasourceNodeData(
+            type="datasource",
+            version="1",
+            title="Datasource",
+            provider_type="plugin",
+            provider_name="p",
+            plugin_id="plug",
+            datasource_name="ds",
+        ),
         graph_init_params=gp,
         graph_runtime_state=gs,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index a5026b40cf..be7cc073db 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -1,4 +1,8 @@
 import pytest
+
+from configs import dify_config
+from core.helper.ssrf_proxy import ssrf_proxy
+from core.workflow.system_variables import default_system_variables
 from graphon.file.file_manager import file_manager
 from graphon.nodes.http_request import (
     BodyData,
@@ -12,10 +16,6 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError
 from graphon.nodes.http_request.executor import Executor
 from graphon.runtime import VariablePool

-from configs import dify_config
-from core.helper.ssrf_proxy import ssrf_proxy
-from core.workflow.system_variables import default_system_variables
-
 HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
     max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
     max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index 4705b3f76e..2e89a2da3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -3,17 +3,17 @@ from typing import Any
 import httpx
 import pytest

-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.file.file_manager import file_manager
-from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
-from graphon.runtime import GraphRuntimeState, VariablePool
-
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.helper.ssrf_proxy import ssrf_proxy
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.node_runtime import DifyFileReferenceFactory
 from core.workflow.system_variables import build_system_variables
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.file.file_manager import file_manager
+from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
+from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params

 HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
     assert default_config["retry_config"]["max_retries"] == 7

-def test_get_default_config_with_malformed_http_request_config_raises_value_error():
-    with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
+def test_get_default_config_with_malformed_http_request_config_raises_type_error():
+    with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
         HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
         start_at=time.perf_counter(),
     )
     return HttpRequestNode(
-        id="http-node",
-        config=node_config,
+        node_id="http-node",
+        config=HttpRequestNodeData.model_validate(node_data),
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
         http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index d16e1233ac..07430498e5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,7 +1,6 @@
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
 from graphon.runtime import VariablePool

-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
-

 def test_render_body_template_replaces_variable_values():
     config = EmailDeliveryConfig(
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index a2cdbbf132..0659984c76 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -10,6 +10,28 @@ from typing import Any
 from unittest.mock import MagicMock

 import pytest
+from pydantic import ValidationError
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
+from core.repositories.human_input_repository import (
+    FormCreateParams,
+    HumanInputFormEntity,
+    HumanInputFormRecipientEntity,
+    HumanInputFormRepository,
+)
+from core.workflow.human_input_adapter import (
+    DeliveryMethodType,
+    EmailDeliveryConfig,
+    EmailDeliveryMethod,
+    EmailRecipients,
+    EmailRecipientType,
+    ExternalRecipient,
+    MemberRecipient,
+    WebAppDeliveryMethod,
+    _WebAppDeliveryConfig,
+)
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from graphon.entities import GraphInitParams
 from graphon.node_events import PauseRequestedEvent
 from graphon.node_events.node import StreamCompletedEvent
@@ -28,28 +50,6 @@ from graphon.nodes.human_input.enums import (
 )
 from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.runtime import GraphRuntimeState, VariablePool
-from pydantic import ValidationError
-
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
-from core.repositories.human_input_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRecipientEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.human_input_compat import (
-    DeliveryMethodType,
-    EmailDeliveryConfig,
-    EmailDeliveryMethod,
-    EmailRecipients,
-    EmailRecipientType,
-    ExternalRecipient,
-    MemberRecipient,
-    WebAppDeliveryMethod,
-    _WebAppDeliveryConfig,
-)
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
 from libs.datetime_utils import naive_utc_now
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
         entity.status_value = HumanInputFormStatus.SUBMITTED


+def _build_human_input_node(
+    *,
+    node_id: str,
+    node_data: HumanInputNodeData | Mapping[str, Any],
+    graph_init_params: GraphInitParams,
+    graph_runtime_state: GraphRuntimeState,
+    runtime: DifyHumanInputNodeRuntime,
+) -> HumanInputNode:
+    typed_node_data = (
+        node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
+    )
+    return HumanInputNode(
+        node_id=node_id,
+        config=typed_node_data,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        runtime=runtime,
+    )
+
+
 class TestDeliveryMethod:
     """Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
             data[field_name] = value

             with pytest.raises(ValidationError) as exc_info:
-                UserAction(**data)
+                UserAction.model_validate(data)

             errors = exc_info.value.errors()
             assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
         runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
         runtime._build_form_repository = MagicMock(return_value=mock_repo)  # type: ignore[attr-defined]

-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
+        node = _build_human_input_node(
+            node_id=config["id"],
+            node_data=config["data"],
             graph_init_params=graph_init_params,
             graph_runtime_state=runtime_state,
             runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
         runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
         runtime._build_form_repository = MagicMock(return_value=mock_repo)  # type: ignore[attr-defined]

-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
+        node = _build_human_input_node(
+            node_id=config["id"],
+            node_data=config["data"],
             graph_init_params=graph_init_params,
             graph_runtime_state=runtime_state,
             runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
         runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
         runtime._build_form_repository = MagicMock(return_value=mock_repo)  # type: ignore[attr-defined]

-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
+        node = _build_human_input_node(
+            node_id=config["id"],
+            node_data=config["data"],
             graph_init_params=graph_init_params,
             graph_runtime_state=runtime_state,
             runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
         runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
         runtime._build_form_repository = MagicMock(return_value=mock_repo)  # type: ignore[attr-defined]

-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
+        node = _build_human_input_node(
+            node_id=config["id"],
+            node_data=config["data"],
             graph_init_params=graph_init_params,
             graph_runtime_state=runtime_state,
             runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
         form_repository = InMemoryHumanInputFormRepository()
         runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
         runtime._build_form_repository = MagicMock(return_value=form_repository)  # type: ignore[attr-defined]
-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
+        node = _build_human_input_node(
+            node_id=config["id"],
+            node_data=config["data"],
             graph_init_params=graph_init_params,
             graph_runtime_state=runtime_state,
             runtime=runtime,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index 52802c7ce1..4a9438b14f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -1,6 +1,9 @@
 import datetime
 from types import SimpleNamespace

+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes
 from graphon.graph_events import (
@@ -8,13 +11,10 @@ from graphon.graph_events import (
     NodeRunHumanInputFormTimeoutEvent,
     NodeRunStartedEvent,
 )
+from graphon.nodes.human_input.entities import HumanInputNodeData
 from graphon.nodes.human_input.enums import HumanInputFormStatus
 from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.runtime import GraphRuntimeState, VariablePool
-
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import default_system_variables
 from libs.datetime_utils import naive_utc_now

@@ -26,6 +26,28 @@ class _FakeFormRepository:
         return self._form


+def _create_human_input_node(
+    *,
+    config: dict,
+    graph_init_params: GraphInitParams,
+    graph_runtime_state: GraphRuntimeState,
+    repo: _FakeFormRepository,
+) -> HumanInputNode:
+    node_data = (
+        config["data"]
+        if isinstance(config["data"], HumanInputNodeData)
+        else HumanInputNodeData.model_validate(config["data"])
+    )
+    return HumanInputNode(
+        node_id=config["id"],
+        config=node_data,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        form_repository=repo,
+        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+    )
+
+
 def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
     system_variables = default_system_variables()
     graph_runtime_state = GraphRuntimeState(
@@ -81,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
     )

     repo = _FakeFormRepository(fake_form)
-    return HumanInputNode(
-        id="node-1",
+    return _create_human_input_node(
         config=config,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+        repo=repo,
     )
@@ -146,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
     )

     repo = _FakeFormRepository(fake_form)
-    return HumanInputNode(
-        id="node-1",
+    return _create_human_input_node(
         config=config,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
-        form_repository=repo,
-        runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+        repo=repo,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index bbfe350f7e..8ffce39cd6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -2,7 +2,10 @@
 from collections.abc import Mapping
 from typing import Any

 import pytest
+
+from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
+from graphon.nodes.iteration.entities import IterationNodeData
 from graphon.nodes.iteration.exc import IterationGraphNotFoundError
 from graphon.nodes.iteration.iteration_node import IterationNode
 from graphon.runtime import (
@@ -11,8 +14,6 @@ from graphon.runtime import (
     GraphRuntimeState,
     VariablePool,
 )
-
-from core.workflow.system_variables import default_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -44,17 +45,14 @@ def _build_iteration_node(
 ) -> IterationNode:
     init_params = build_test_graph_init_params(graph_config=graph_config)
     return IterationNode(
-        id="iteration-node",
-        config={
-            "id": "iteration-node",
-            "data": {
-                "type": "iteration",
-                "title": "Iteration",
-                "iterator_selector": ["start", "items"],
-                "output_selector": ["iteration-node", "output"],
-                "start_node_id": start_node_id,
-            },
-        },
+        node_id="iteration-node",
+        config=IterationNodeData(
+            type="iteration",
+            title="Iteration",
+            iterator_selector=["start", "items"],
+            output_selector=["iteration-node", "output"],
+            start_node_id=start_node_id,
+        ),
         graph_init_params=init_params,
         graph_runtime_state=runtime_state,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index f8802138b5..f254fc3d09 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -3,9 +3,6 @@ import uuid
 from unittest.mock import Mock

 import pytest
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variables.segments import StringSegment

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -19,6 +16,9 @@ from core.workflow.nodes.knowledge_index.protocols import (
     SummaryIndexServiceProtocol,
 )
 from core.workflow.system_variables import SystemVariableKey, build_system_variables
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variables.segments import StringSegment
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -93,6 +93,25 @@ def sample_chunks():
     }


+def _build_node(
+    *,
+    node_id: str,
+    node_data: KnowledgeIndexNodeData | dict[str, object],
+    graph_init_params,
+    graph_runtime_state,
+) -> KnowledgeIndexNode:
+    return KnowledgeIndexNode(
+        node_id=node_id,
+        config=(
+            node_data
+            if isinstance(node_data, KnowledgeIndexNodeData)
+            else KnowledgeIndexNodeData.model_validate(node_data)
+        ),
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+    )
+
+
 class TestKnowledgeIndexNode:
     """
     Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
         }

         # Act
-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
             "data": sample_node_data.model_dump(),
         }

-        node = KnowledgeIndexNode(
-            id=node_id,
-            config=config,
+        node = _build_node(
+            node_id=node_id,
+            node_data=config["data"],
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index ab64be59ad..e923ee761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -3,10 +3,6 @@ import uuid
 from unittest.mock import Mock

 import pytest
-from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variables import StringSegment

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.nodes.knowledge_retrieval.entities import (
@@ -18,9 +14,17 @@ from core.workflow.nodes.knowledge_retrieval.entities import (
     SingleRetrievalConfig,
 )
 from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import (
+    KnowledgeRetrievalNode,
+    _normalize_metadata_filter_scalar,
+    _normalize_metadata_filter_sequence_item,
+)
 from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
 from core.workflow.system_variables import build_system_variables
+from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variables import StringSegment
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -85,6 +89,12 @@ def sample_node_data():
     )


+def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None:
+    assert _normalize_metadata_filter_scalar(3) == 3
+    assert _normalize_metadata_filter_scalar(True) == "True"
+    assert _normalize_metadata_filter_sequence_item(4) == "4"
+
+
 class TestKnowledgeRetrievalNode:
     """
     Test suite for KnowledgeRetrievalNode.
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:

         # Act
         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
         config = {"id": node_id, "data": node_data.model_dump()}

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
         mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
         node_id = str(uuid.uuid4())
         config = {"id": node_id, "data": node_data.model_dump()}
         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index fdf1706765..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,23 +1,42 @@
+from types import SimpleNamespace
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
 from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
 from graphon.nodes.list_operator.node import ListOperatorNode
 from graphon.runtime import GraphRuntimeState
 from graphon.variables import ArrayNumberSegment, ArrayStringSegment

-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
-

 class TestListOperatorNode:
     """Comprehensive tests for ListOperatorNode."""

+    @staticmethod
+    def _build_node(*, config, graph_init_params, graph_runtime_state):
+        return ListOperatorNode(
+            node_id="test",
+            config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
+    @staticmethod
+    def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+        return {
+            "enabled": True,
+            "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+        }
+
     @pytest.fixture
     def mock_graph_runtime_state(self):
         """Create mock GraphRuntimeState."""
         mock_state = MagicMock(spec=GraphRuntimeState)
         mock_variable_pool = MagicMock()
+        mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
         mock_state.variable_pool = mock_variable_pool
         return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:

         def _create_node(config, mock_variable):
             mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
-            return ListOperatorNode(
-                id="test",
-                config={"id": "test", "data": config},
+            return self._build_node(
+                config=config,
                 graph_init_params=graph_init_params,
                 graph_runtime_state=mock_graph_runtime_state,
             )
@@ -64,9 +82,8 @@ class TestListOperatorNode:
             "limit": {"enabled": False},
         }

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -109,9 +126,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=[])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -128,11 +144,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "contains",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("contains", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -140,9 +152,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -157,11 +168,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "not contains",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("not contains", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -169,9 +176,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -186,11 +192,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": ">",
-                "value": "5",
-            },
+            "filter_by": self._filter_by(">", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -198,9 +200,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -226,9 +227,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -254,9 +254,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -282,9 +281,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -299,11 +297,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": ">",
-                "value": "3",
-            },
+            "filter_by": self._filter_by(">", "3"),
             "order_by": {
                 "enabled": True,
                 "value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -341,9 +334,8 @@ class TestListOperatorNode:

         mock_graph_runtime_state.variable_pool.get.return_value = None

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -366,9 +358,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["first", "middle", "last"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -384,11 +375,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "start with",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("start with", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -396,9 +383,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -413,11 +399,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "end with",
-                "value": "le",
-            },
+            "filter_by": self._filter_by("end with", "le"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -425,9 +407,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -442,11 +423,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "=",
-                "value": "5",
-            },
+            "filter_by": self._filter_by("=", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -454,9 +431,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -471,11 +447,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "≠",
-                "value": "5",
-            },
+            "filter_by": self._filter_by("≠", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -483,9 +455,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -511,9 +482,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index c784f805c0..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -1,6 +1,8 @@
 from unittest import mock

 import pytest
+
+from core.model_manager import ModelInstance
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -33,8 +35,6 @@ from graphon.nodes.llm.exc import (
 from graphon.runtime import VariablePool
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment

-from core.model_manager import ModelInstance
-

 def _build_model_schema(
     *,
@@ -71,8 +71,8 @@ def _build_image_file(
     mime_type: str = "image/png",
 ) -> File:
     return File(
-        id=file_id,
-        type=FileType.IMAGE,
+        file_id=file_id,
+        file_type=FileType.IMAGE,
         filename=f"{file_id}{extension}",
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
 def _fetch_prompt_messages_with_mocked_content(content):
     variable_pool = VariablePool.empty()
     model_instance = mock.MagicMock(spec=ModelInstance)
+    model_schema = mock.MagicMock()
+    model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
     prompt_template = [
         LLMNodeChatModelMessage(
             text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
     with (
         mock.patch(
             "graphon.nodes.llm.llm_utils.fetch_model_schema",
-            return_value=mock.MagicMock(features=[]),
+            return_value=model_schema,
         ),
         mock.patch(
             "graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 7841bf05ad..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -5,6 +5,19 @@ from collections.abc import Sequence
 from unittest import mock

 import pytest
+
+from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
+from core.app.llm.model_access import (
+    DifyCredentialsProvider,
+    DifyModelFactory,
+    build_dify_model_access,
+    fetch_model_config,
+)
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
+from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities.common_entities import I18nObject
@@ -67,19 +80,6 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.template_rendering import TemplateRenderError
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
-
-from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
-from core.app.llm.model_access import (
-    DifyCredentialsProvider,
-    DifyModelFactory,
-    build_dify_model_access,
-    fetch_model_config,
-)
-from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
-from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.system_variables import default_system_variables
 from models.provider import ProviderType
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -140,8 +140,8 @@ def _build_image_file(
     mime_type: str = "image/png",
 ) -> File:
     return File(
-        id=file_id,
-        type=FileType.IMAGE,
+        file_id=file_id,
+        file_type=FileType.IMAGE,
         filename=f"{file_id}{extension}",
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
     mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
-    node_config = {
-        "id": "1",
-        "data": llm_node_data.model_dump(),
-    }
     http_client = mock.MagicMock()
     node = LLMNode(
-        id="1",
-        config=node_config,
+        node_id="1",
+        config=llm_node_data,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():

 def test_fetch_files_with_file_segment():
     file = File(
-        id="1",
-        type=FileType.IMAGE,
+        file_id="1",
+        file_type=FileType.IMAGE,
         filename="test.jpg",
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="1",
         storage_key="",
@@ -420,16 +416,16 @@ def test_fetch_files_with_array_file_segment():
     files = [
         File(
-            id="1",
-            type=FileType.IMAGE,
+            file_id="1",
+            file_type=FileType.IMAGE,
             filename="test1.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="1",
             storage_key="",
         ),
         File(
-            id="2",
-            type=FileType.IMAGE,
+            file_id="2",
+            file_type=FileType.IMAGE,
             filename="test2.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="2",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
     mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
-    node_config = {
-        "id": "1",
-        "data": llm_node_data.model_dump(),
-    }
     http_client = mock.MagicMock()
     node = LLMNode(
-        id="1",
-        config=node_config,
+        node_id="1",
+        config=llm_node_data,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
             mime_type="image/png",
         )
         mock_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             related_id=str(uuid.uuid4()),
             filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
             mime_type="image/jpg",
         )
         mock_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             related_id=str(uuid.uuid4()),
             filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
         image_b64_data = base64.b64encode(image_raw_data).decode()

         mock_saved_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             filename="test.png",
             extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
             file_saver=file_saver,
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             reasoning_format="separated",
         )
     )
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
             file_saver=mock.MagicMock(spec=LLMFileSaver),
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             model_instance=_build_prepared_llm_mock(),
             reasoning_format="separated",
             request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
             file_saver=mock.MagicMock(spec=LLMFileSaver),
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             model_instance=model_instance,
         )
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
index 1c362a0a03..8f8ec49f14 100644
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
@@ -6,6 +6,8 @@ from dataclasses import dataclass
 from typing import Any

 import pytest
+
+from factories.variable_factory import build_segment_with_type
 from graphon.model_runtime.entities import LLMMode
 from graphon.nodes.llm import ModelConfig, VisionConfig
 from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
@@ -18,8 +20,6 @@ from graphon.nodes.parameter_extractor.exc import (
 from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from graphon.variables.types import SegmentType

-from factories.variable_factory import build_segment_with_type
-

 @dataclass
 class ValidTestCase:
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index d86e0efe02..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -1,6 +1,8 @@
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.nodes.base.entities import VariableSelector
@@ -8,11 +10,31 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData
 from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
 from graphon.runtime import GraphRuntimeState
 from graphon.template_rendering import TemplateRenderError
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params


+def _build_template_transform_node(
+    *,
+    node_data,
+    graph_init_params,
+    graph_runtime_state,
+    node_id: str = "test_node",
+    **kwargs,
+) -> TemplateTransformNode:
+    typed_node_data = (
+        node_data
+        if isinstance(node_data, TemplateTransformNodeData)
+        else TemplateTransformNodeData.model_validate(node_data)
+    )
+    return TemplateTransformNode(
+        node_id=node_id,
+        config=typed_node_data,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        **kwargs,
+    )
+
+
 class TestTemplateTransformNode:
     """Comprehensive test suite for TemplateTransformNode."""

@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
     def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test that TemplateTransformNode initializes correctly."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
     def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test _get_title method."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
     def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test _get_description method."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
         }

         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()

         with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
-            TemplateTransformNode(
-                id="test_node",
-                config={"id": "test_node", "data": basic_node_data},
+            _build_template_transform_node(
+                node_data=basic_node_data,
                 graph_init_params=graph_init_params,
                 graph_runtime_state=mock_graph_runtime_state,
                 jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -198,9 +214,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Value: " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -218,9 +233,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -238,9 +252,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -260,9 +273,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "1234567890" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -302,9 +314,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -375,8 +386,8 @@ class TestTemplateTransformNode: ) assert mapping == { - "node_123.var1": ["sys", "input1"], - "node_123.empty_selector": [], + "node_123.var1": ("sys", "input1"), + "node_123.empty_selector": (), } def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): @@ -409,9 +420,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a static message." 
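The `_extract_variable_selector_to_variable_mapping` assertions above now expect immutable tuple selectors (`("sys", "input1")`, `()`) instead of lists. A minimal sketch of the invariant those assertions pin down, with `extract_selector_mapping` as an illustrative stand-in rather than the real graphon helper:

```python
from typing import Any


# Illustrative stand-in for the node's selector-mapping extraction: whatever
# list the node data carries is frozen into a tuple keyed by "node_id.variable".
def extract_selector_mapping(node_id: str, variables: list[dict[str, Any]]) -> dict[str, tuple[str, ...]]:
    mapping: dict[str, tuple[str, ...]] = {}
    for variable in variables:
        selector = variable.get("value_selector") or []
        mapping[f"{node_id}.{variable['variable']}"] = tuple(selector)
    return mapping


assert extract_selector_mapping(
    "node_123",
    [
        {"variable": "var1", "value_selector": ["sys", "input1"]},
        {"variable": "empty_selector", "value_selector": []},
    ],
) == {"node_123.var1": ("sys", "input1"), "node_123.empty_selector": ()}
```

Tuples also make the mapping values hashable and guard against accidental in-place mutation by callers.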
- node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -448,9 +458,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Total: $31.5" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -477,9 +486,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -507,9 +515,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index bd22a8e318..a846efbb43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -1,14 +1,15 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, ) from graphon.runtime import GraphRuntimeState - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 @@ -37,15 +38,13 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( - id="test_node", - config={ - "id": "test_node", - "data": { - "title": "Template Transform", - "variables": [], - "template": "hello", - }, - }, + node_id="test_node", + config=TemplateTransformNodeData( + title="Template Transform", + type="template-transform", + variables=[], + template="hello", + ), graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=MagicMock(), @@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri assert mapping == { "node_123.validated": ["sys", "input1"], - "node_123.raw": ["sys", "input2"], + 
"node_123.raw": ("sys", "input2"), } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index e11ebf6eb8..364408ead6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,15 @@ from collections.abc import Mapping import pytest -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_runtime import resolve_dify_run_context from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - }, - } - ) +def _build_node_config() -> dict[str, object]: + return { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + ), + } + + +def _build_node_data() -> _SampleNodeData: + return _build_node_config()["data"] # type: ignore[return-value] def test_node_hydrates_data_during_initialization(): @@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization(): init_params, runtime_state = _build_context(graph_config) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum(): ) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): - node_data = _SampleNodeData.model_validate( - { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - } - ) + node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar") assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" @@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): def test_node_hydration_preserves_compatibility_extra_fields(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) - node_config = NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - "compat_flag": True, - }, - } - ) + node_config = { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + compat_flag=True, + ), + } node = _SampleNode( - 
id="node-1", - config=node_config, + node_id="node-1", + config=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 555ff0c945..dd75b32593 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,23 +4,25 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, + _extract_text_from_file, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from graphon.variables import ArrayFileSegment +from graphon.variables import ArrayFileSegment, FileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params @@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - node_config = {"id": "test_node_id", "data": node_data.model_dump()} http_client = Mock() node = DocumentExtractorNode( - id="test_node_id", - config=node_config, + node_id="test_node_id", + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"] - mock_excel_instance.parse.side_effect = [df, Exception("Parse error")] + mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_mixed_content" @@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"] - mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")] + mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_all_bad_sheets" @@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert mock_excel_instance.parse.call_count == 2 +@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook")) +def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file): + with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"): + _extract_text_from_excel(b"broken") + + @patch("pandas.ExcelFile") def 
test_extract_text_from_excel_numeric_type_column(mock_excel_file): """Test extracting text from Excel file with numeric column names.""" @@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): assert expected_manual == result +@pytest.mark.parametrize( + ("extension", "mime_type"), + [ + (".xlsx", "text/plain"), + (None, "application/vnd.ms-excel"), + ], +) +def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type): + file = Mock(spec=File) + file.extension = extension + file.mime_type = mime_type + + with ( + patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", + ), + patch( + "graphon.nodes.document_extractor.node._extract_text_from_excel", + return_value="excel text", + ) as mock_extract, + ): + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + assert result == "excel text" + mock_extract.assert_called_once_with(b"excel") + + +def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): + file = Mock(spec=File) + file.extension = None + file.mime_type = None + + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"unknown", + ): + with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"): + _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + +def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_list = Mock(spec=ArrayFileSegment) + file_list.value = [Mock(spec=File)] + mock_graph_runtime_state.variable_pool.get.return_value = file_list + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("bad file"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "bad file" + + +def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("single file failed"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "single file failed" + + +def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + return_value="single file text", + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {"text": "single file text"} + + def _make_docx_zip(use_backslash: bool) -> bytes: """Helper to 
build a minimal in-memory DOCX zip. diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 1b14f0ab13..aa9a1360b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,6 +3,11 @@ import uuid from unittest.mock import MagicMock, Mock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -11,14 +16,23 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params +def _build_if_else_node( + *, + node_data: IfElseNodeData | dict[str, object], + init_params, + graph_runtime_state, +) -> IfElseNode: + return IfElseNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + ) + + def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} @@ -61,9 +75,8 @@ def test_execute_if_else_result_true(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "and", @@ -104,13 +117,8 @@ def test_execute_if_else_result_true(): {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -155,9 +163,8 @@ def test_execute_if_else_result_false(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "or", @@ -174,13 +181,8 @@ def test_execute_if_else_result_false(): }, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -222,11 +224,6 @@ def test_array_file_contains_file_name(): ], ) - node_config = { - "id": "if-else", - "data": node_data.model_dump(), - } - # Create properly configured mock for graph_init_params graph_init_params = Mock() graph_init_params.workflow_id = "test_workflow" @@ -242,17 +239,12 @@ def test_array_file_contains_file_name(): } } - node = IfElseNode( - 
id=str(uuid.uuid4()), - graph_init_params=graph_init_params, - graph_runtime_state=Mock(), - config=node_config, - ) + node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock()) node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", @@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition): "logical_operator": "and", "conditions": [condition.model_dump()], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() @@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions(): ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={ - "id": "if-else", - "data": node_data, - }, ) # Mock db.session.close() @@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure(): } ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d28c3e01e5..465a4c0ff4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -16,7 +18,14 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + +def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: + return ListOperatorNode( + node_id="test_node_id", + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=MagicMock(), + ) @pytest.fixture @@ -35,10 +44,6 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData.model_validate(config) - node_config = { - "id": "test_node_id", - "data": node_data.model_dump(), - } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() graph_init_params.workflow_id = "test_workflow" @@ -54,12 +59,7 @@ def list_operator_node(): } } - node = ListOperatorNode( - id="test_node_id", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=MagicMock(), - ) + node = _build_list_operator_node(node_data, graph_init_params) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node @@ -70,28 +70,28 @@ def 
test_filter_files_by_type(list_operator_node): files = [ File( filename="image1.jpg", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", ), File( filename="document1.pdf", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", ), File( filename="image2.png", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", ), File( filename="audio1.mp3", - type=FileType.AUDIO, + file_type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", extension=".txt", @@ -156,7 +156,7 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, extension=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 833c303052..5655f80737 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables): inputs=user_inputs, ) - config = { - "id": "start", - "data": StartNodeData(title="Start", variables=variables).model_dump(), - } + node_data = StartNodeData(title="Start", variables=variables) graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, @@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables): ) return StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -109,7 +106,7 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"): + with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"): node._run() @@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot(): 
inputs={"profile": {"age": 20, "name": "Tom"}}, ) - config = { - "id": "start", - "data": StartNodeData( - title="Start", - variables=[ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ], - ).model_dump(), - } + node_data = StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 1587014802..284af68319 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,14 +8,15 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest + +from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment - -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only @@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode: runtime = _StubToolRuntime() node = ToolNode( - id="node-instance", - config=config, + node_id="node-instance", + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode: return node -def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: +def _collect_events(generator: Generator) -> list[Any]: events: list[Any] = [] try: while True: events.append(next(generator)) - except StopIteration as stop: - return events, stop.value + except StopIteration: + return events def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: @@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li node_id=tool_node._node_id, tool_runtime=ToolRuntimeHandle(raw=object()), ) - return _collect_events(generator) + events = _collect_events(generator) + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert completed_events + return events, completed_events[-1].node_run_result.llm_usage def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", @@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): def test_image_link_messages_use_tool_file_id_metadata(tool_node: 
ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index c4dfc5a179..438af211f3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -6,11 +6,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType -from graphon.nodes.tool.exc import ToolRuntimeInvocationError -from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import VariablePool from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError @@ -22,6 +17,11 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 952e798430..e3b5e3b591 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState - from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -28,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": TRIGGER_PLUGIN_NODE_TYPE, - "title": "Trigger Event", - "plugin_id": "plugin-id", - "provider_id": "provider-id", - "event_name": "event-name", - "subscription_id": 
"subscription-id", - "plugin_unique_identifier": "plugin-unique-identifier", - "event_parameters": {}, - }, - } +def _build_node_data() -> TriggerEventNodeData: + return TriggerEventNodeData( + type=TRIGGER_PLUGIN_NODE_TYPE, + title="Trigger Event", + plugin_id="plugin-id", + provider_id="provider-id", + event_name="event-name", + subscription_id="subscription-id", + plugin_unique_identifier="plugin-unique-identifier", + event_parameters={}, ) def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index f1132af02b..617554ee17 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,5 +1,4 @@ import pytest -from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -7,6 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) +from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index cccd3fb676..07d03bec05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -6,12 +6,9 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro when passing files to downstream LLM nodes. 
""" +from typing import Any from unittest.mock import Mock, patch -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -21,6 +18,9 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_variable_pool @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node @@ -97,7 +88,7 @@ def create_test_file_dict( } -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="webhook-node-1", @@ -105,7 +96,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool: ) -def expected_factory_mapping(file_dict: dict) -> dict: +def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]: return {**file_dict, "upload_file_id": file_dict["related_id"]} diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 34c66a4f9f..b839490d3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,11 +1,7 @@ +from typing import Any from unittest.mock import patch import pytest -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -18,16 +14,16 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType 
+from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -47,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -56,13 +52,10 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="1", @@ -224,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -233,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -268,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -283,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -316,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + 
transfer_method=FileTransferMethod(mapping["transfer_method"]),
+                related_id=mapping.get("related_id"),
+                filename=mapping.get("filename"),
+                extension=mapping.get("extension"),
+                mime_type=mapping.get("mime_type"),
+                size=mapping.get("size", -1),
+                storage_key=mapping.get("storage_key", ""),
+                remote_url=mapping.get("url"),
+            )
         mock_file_factory.side_effect = _to_file
         result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..8b5fceeb37
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,350 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+    DeliveryMethodType,
+    EmailDeliveryConfig,
+    EmailDeliveryMethod,
+    EmailRecipients,
+    WebAppDeliveryMethod,
+    _WebAppDeliveryConfig,
+    adapt_human_input_node_data_for_graph,
+    adapt_node_config_for_graph,
+    adapt_node_data_for_graph,
+    is_human_input_webapp_enabled,
+    parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+    variable_pool = SimpleNamespace(
+        convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+    )
+
+    rendered = EmailDeliveryConfig.render_body_template(
+        body="Open {{#url#}} and use {{#node.value#}}",
+        url="https://example.com",
+        variable_pool=variable_pool,
+    )
+    sanitized = EmailDeliveryConfig.sanitize_subject("Hello<script>alert(1)</script>\r\n Team")
+    html = EmailDeliveryConfig.render_markdown_body(
+        "**Hello** [mail](mailto:test@example.com)"
+    )
+
+    assert rendered == "Open https://example.com and use 42"
+    assert sanitized == "Hello alert(1) Team"
+    assert "Hello" in html
+    assert '<a href="mailto:test@example.com"' in html
-    sanitized = EmailDeliveryConfig.sanitize_subject("Hello<script>alert(1)</script>\r\n Team")
-    html = EmailDeliveryConfig.render_markdown_body(
-        "**Hello** [mail](mailto:test@example.com)"
-    )
-
-    assert rendered == "Open https://example.com and use 42"
-    assert sanitized == "Hello alert(1) Team"
-    assert "Hello" in html
-    assert '<a href="mailto:test@example.com"' in html
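The sanitize assertions in the block above pin down tag stripping plus whitespace collapsing. A rough standalone equivalent of that behaviour, assuming a regex-based implementation (the real `EmailDeliveryConfig.sanitize_subject` may differ):

```python
import re


def sanitize_subject(subject: str) -> str:
    # Drop anything tag-shaped, keeping inner text, then collapse all runs of
    # whitespace (including the \r\n header-injection vector) to single spaces.
    without_tags = re.sub(r"<[^>]*>", " ", subject)
    return re.sub(r"\s+", " ", without_tags).strip()


assert sanitize_subject("Hello<script>alert(1)</script>\r\n Team") == "Hello alert(1) Team"
```

Collapsing `\r\n` matters because a raw newline in a subject line is the classic vector for injecting extra email headers.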
'@/context/provider-context' -import { defaultSystemFeatures } from '@/types/feature' import CustomPage from '../index' +const render = (ui: ReactElement) => renderWithSystemFeatures(ui, { + systemFeatures: { + branding: { + enabled: true, + workspace_logo: 'https://example.com/workspace-logo.png', + }, + }, +}) + const { mockToast } = vi.hoisted(() => { const mockToast = Object.assign(vi.fn(), { success: vi.fn(), @@ -44,17 +52,13 @@ vi.mock('@/context/app-context', async (importOriginal) => { useAppContext: vi.fn(), } }) -vi.mock('@/context/global-public-context', () => ({ - useGlobalPublicStore: vi.fn(), -})) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: mockToast, })) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseModalContext = vi.mocked(useModalContext) const mockUseAppContext = vi.mocked(useAppContext) -const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) const createProviderContext = ({ enableBilling = false, @@ -93,15 +97,6 @@ const createAppContextValue = (): AppContextValue => ({ isValidatingCurrentWorkspace: false, }) -const createSystemFeatures = (): SystemFeatures => ({ - ...defaultSystemFeatures, - branding: { - ...defaultSystemFeatures.branding, - enabled: true, - workspace_logo: 'https://example.com/workspace-logo.png', - }, -}) - describe('CustomPage', () => { const setShowPricingModal = vi.fn() @@ -113,10 +108,6 @@ describe('CustomPage', () => { setShowPricingModal, } as unknown as ReturnType) mockUseAppContext.mockReturnValue(createAppContextValue()) - mockUseGlobalPublicStore.mockImplementation(selector => selector({ - systemFeatures: createSystemFeatures(), - setSystemFeatures: vi.fn(), - })) }) // Integration coverage for the page and its child custom brand section. diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx index 38f651882d..cd85a91230 100644 --- a/web/app/components/custom/custom-page/index.tsx +++ b/web/app/components/custom/custom-page/index.tsx @@ -20,7 +20,7 @@ const CustomPage = () => {
{t('upgradeTip.title', { ns: 'custom' })}
{t('upgradeTip.des', { ns: 'custom' })}
-
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
+
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
)} diff --git a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx index b54b800049..1ea03c7bc6 100644 --- a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type ChatPreviewCardProps = { @@ -22,14 +22,14 @@ const ChatPreviewCard = ({
-
Chatflow App
+
Chatflow App
-
+
-
Hello! How can I assist you today?
+
Hello! How can I assist you today?
-
Talk to Dify
+
Talk to Dify
diff --git a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx index 8a0feffbc4..51c0db53d7 100644 --- a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx @@ -20,7 +20,7 @@ const PoweredByBrand = ({ return ( <> -
POWERED BY
+
POWERED BY
{previewLogo ? logo : } diff --git a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx index 573c2e9c1c..b4f30827bc 100644 --- a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type WorkflowPreviewCardProps = { @@ -22,14 +22,14 @@ const WorkflowPreviewCard = ({
-
Workflow App
+
Workflow App
-
RUN ONCE
-
RUN BATCH
+
RUN ONCE
+
RUN BATCH
@@ -44,7 +44,7 @@ const WorkflowPreviewCard = ({
diff --git a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx
index a2730ffd27..99cbc03b32 100644
--- a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx
+++ b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx
@@ -1,9 +1,10 @@
 import type { ChangeEvent } from 'react'
 import type { AppContextValue } from '@/context/app-context'
 import type { SystemFeatures } from '@/types/feature'
-import { act, renderHook } from '@testing-library/react'
+import { act } from '@testing-library/react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 import { createMockProviderContextValue } from '@/__mocks__/provider-context'
+import { renderHookWithSystemFeatures } from '@/__tests__/utils/mock-system-features'
 import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
 import { defaultPlan } from '@/app/components/billing/config'
 import { Plan } from '@/app/components/billing/type'
@@ -13,12 +14,22 @@ import {
   useAppContext,
   userProfilePlaceholder,
 } from '@/context/app-context'
-import { useGlobalPublicStore } from '@/context/global-public-context'
 import { useProviderContext } from '@/context/provider-context'
 import { updateCurrentWorkspace } from '@/service/common'
-import { defaultSystemFeatures } from '@/types/feature'
 import useWebAppBrand from '../use-web-app-brand'
 
+let currentBrandingOverrides: Partial<SystemFeatures['branding']> = {}
+const renderHook = <Result, Props>(callback: (props: Props) => Result) =>
+  renderHookWithSystemFeatures(callback, {
+    systemFeatures: {
+      branding: {
+        enabled: true,
+        workspace_logo: 'https://example.com/workspace-logo.png',
+        ...currentBrandingOverrides,
+      },
+    },
+  })
+
 const { mockNotify, mockToast } = vi.hoisted(() => {
   const mockNotify = vi.fn()
   const mockToast = Object.assign(mockNotify, {
@@ -33,7 +44,7 @@ const { mockNotify, mockToast } = vi.hoisted(() => {
   return { mockNotify, mockToast }
 })
 
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: mockToast,
 }))
 vi.mock('@/service/common', () => ({
@@ -49,9 +60,6 @@ vi.mock('@/context/app-context', async (importOriginal) => {
 vi.mock('@/context/provider-context', () => ({
   useProviderContext: vi.fn(),
 }))
-vi.mock('@/context/global-public-context', () => ({
-  useGlobalPublicStore: vi.fn(),
-}))
 vi.mock('@/app/components/base/image-uploader/utils', () => ({
   imageUpload: vi.fn(),
   getImageUploadErrorMessage: vi.fn(),
@@ -60,7 +68,6 @@ vi.mock('@/app/components/base/image-uploader/utils', () => ({
 const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace)
 const mockUseAppContext = vi.mocked(useAppContext)
 const mockUseProviderContext = vi.mocked(useProviderContext)
-const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
 const mockImageUpload = vi.mocked(imageUpload)
 const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage)
 
@@ -80,16 +87,6 @@ const createProviderContext = ({
   })
 }
 
-const createSystemFeatures = (brandingOverrides: Partial<SystemFeatures['branding']> = {}): SystemFeatures => ({
-  ...defaultSystemFeatures,
-  branding: {
-    ...defaultSystemFeatures.branding,
-    enabled: true,
-    workspace_logo: 'https://example.com/workspace-logo.png',
-    ...brandingOverrides,
-  },
-})
-
 const createAppContextValue = (overrides: Partial<AppContextValue> = {}): AppContextValue => {
   const { currentWorkspace: currentWorkspaceOverride, ...restOverrides } = overrides
   const workspaceOverrides: Partial<AppContextValue['currentWorkspace']> = currentWorkspaceOverride ?? {}
@@ -122,21 +119,16 @@ describe('useWebAppBrand', () => {
   let appContextValue: AppContextValue
-  let systemFeatures: SystemFeatures
 
   beforeEach(() => {
     vi.clearAllMocks()
     appContextValue = createAppContextValue()
-    systemFeatures = createSystemFeatures()
+    currentBrandingOverrides = {}
     mockUpdateCurrentWorkspace.mockResolvedValue(appContextValue.currentWorkspace)
     mockUseAppContext.mockImplementation(() => appContextValue)
     mockUseProviderContext.mockReturnValue(createProviderContext())
-    mockUseGlobalPublicStore.mockImplementation(selector => selector({
-      systemFeatures,
-      setSystemFeatures: vi.fn(),
-    }))
     mockGetImageUploadErrorMessage.mockReturnValue('upload error')
   })
 
@@ -174,10 +166,7 @@ describe('useWebAppBrand', () => {
   })
 
   it('should fall back to an empty workspace logo when branding is disabled', () => {
-    systemFeatures = createSystemFeatures({
-      enabled: false,
-      workspace_logo: '',
-    })
+    currentBrandingOverrides = { enabled: false, workspace_logo: '' }
 
     const { result } = renderHook(() => useWebAppBrand())
 
diff --git a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts
index 44810de1fd..e24edab421 100644
--- a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts
+++ b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts
@@ -1,13 +1,14 @@
 import type { ChangeEvent } from 'react'
+import { toast } from '@langgenius/dify-ui/toast'
+import { useSuspenseQuery } from '@tanstack/react-query'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils'
-import { toast } from '@/app/components/base/ui/toast'
 import { Plan } from '@/app/components/billing/type'
 import { useAppContext } from '@/context/app-context'
-import { useGlobalPublicStore } from '@/context/global-public-context'
 import { useProviderContext } from '@/context/provider-context'
 import { updateCurrentWorkspace } from '@/service/common'
+import { systemFeaturesQueryOptions } from '@/service/system-features'
 
 const MAX_LOGO_FILE_SIZE = 5 * 1024 * 1024
 const CUSTOM_CONFIG_URL = '/workspaces/custom-config'
@@ -19,7 +20,7 @@ const useWebAppBrand = () => {
   const [fileId, setFileId] = useState('')
   const [imgKey, setImgKey] = useState(() => Date.now())
   const [uploadProgress, setUploadProgress] = useState(0)
-  const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
+  const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
   const isSandbox = enableBilling && plan.type === Plan.sandbox
   const uploading = uploadProgress > 0 && uploadProgress < 100
   const webappLogo = currentWorkspace.custom_config?.replace_webapp_logo || ''
diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx
index 0df5e043e9..53e6ff36ec 100644
--- a/web/app/components/custom/custom-web-app-brand/index.tsx
+++ b/web/app/components/custom/custom-web-app-brand/index.tsx
@@ -1,8 +1,8 @@
+import { Button } from '@langgenius/dify-ui/button'
+import { cn } from '@langgenius/dify-ui/cn'
+import { Switch } from '@langgenius/dify-ui/switch'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Divider from '@/app/components/base/divider'
-import Switch from '@/app/components/base/switch'
-import { cn } from '@/utils/classnames'
 import ChatPreviewCard from './components/chat-preview-card'
 import WorkflowPreviewCard from './components/workflow-preview-card'
 import useWebAppBrand from './hooks/use-web-app-brand'
@@ -31,19 +31,19 @@ const CustomWebAppBrand = () => {
 
   return (
-
+
{t('webapp.removeBrand', { ns: 'custom' })}
-
{t('webapp.changeLogo', { ns: 'custom' })}
-
{t('webapp.changeLogoTip', { ns: 'custom' })}
+
{t('webapp.changeLogo', { ns: 'custom' })}
+
{t('webapp.changeLogoTip', { ns: 'custom' })}
{(!uploadDisabled && webappLogo && !webappBrandRemoved) && ( @@ -64,7 +64,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={uploadDisabled} > - + { (webappLogo || fileId) ? t('change', { ns: 'custom' }) @@ -87,7 +87,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={true} > - + {t('uploading', { ns: 'custom' })} ) @@ -118,8 +118,8 @@ const CustomWebAppBrand = () => { {uploadProgress === -1 && (
{t('uploadedFail', { ns: 'custom' })}
)} -
-
{t('overview.appInfo.preview', { ns: 'appOverview' })}
+
+
{t('overview.appInfo.preview', { ns: 'appOverview' })}
diff --git a/web/app/components/datasets/chunk.tsx b/web/app/components/datasets/chunk.tsx index e0d820a4f3..947f2859c4 100644 --- a/web/app/components/datasets/chunk.tsx +++ b/web/app/components/datasets/chunk.tsx @@ -50,11 +50,11 @@ export const QAPreview: FC = (props) => { return (
- +

{qa.question}

- +

{qa.answer}

diff --git a/web/app/components/datasets/common/credential-icon.tsx b/web/app/components/datasets/common/credential-icon.tsx index a5eb6bfdb6..f993eca585 100644 --- a/web/app/components/datasets/common/credential-icon.tsx +++ b/web/app/components/datasets/common/credential-icon.tsx @@ -1,6 +1,6 @@ +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' -import { cn } from '@/utils/classnames' type CredentialIconProps = { avatarUrl?: string @@ -59,7 +59,7 @@ export const CredentialIcon: React.FC = ({ )} style={{ width: `${size}px`, height: `${size}px` }} > - + {firstLetter}
diff --git a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx index f8f0ce6e12..1251eab9fb 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx @@ -5,34 +5,7 @@ import * as React from 'react' import { ChunkingMode, DataSourceType } from '@/models/datasets' import DocumentPicker from '../index' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Mock useDocumentList hook with controllable return value let mockDocumentListData: { data: SimpleDocumentDetail[] } | undefined @@ -152,6 +125,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('DocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -165,7 +142,7 @@ describe('DocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name when provided', () => { @@ -273,7 +250,7 @@ describe('DocumentPicker', () => { onChange, }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle value with all fields', () => { @@ -318,13 +295,13 @@ describe('DocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should open popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Verify click handler is called @@ -430,7 +407,7 @@ describe('DocumentPicker', () => { ) // The component should use the new callback - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should memoize handleChange callback with useCallback', () => { @@ -440,7 +417,7 @@ describe('DocumentPicker', () => { renderComponent({ onChange }) // Verify component renders correctly, callback memoization is internal - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -518,7 +495,7 @@ describe('DocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Trigger click should be handled @@ -591,7 +568,7 @@ describe('DocumentPicker', () => { renderComponent() // When loading, component should still render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should fetch documents on mount', () => { @@ -611,7 +588,7 @@ describe('DocumentPicker', () => { renderComponent() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle undefined data response', () => { @@ -620,7 +597,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -732,13 +709,13 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = 
screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -795,7 +772,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle document list mapping with various data_source_detail_dict states', () => { @@ -819,7 +796,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash during mapping - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -829,13 +806,13 @@ describe('DocumentPicker', () => { it('should handle empty datasetId', () => { renderComponent({ datasetId: '' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle UUID format datasetId', () => { renderComponent({ datasetId: '123e4567-e89b-12d3-a456-426614174000' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -926,6 +903,7 @@ describe('DocumentPicker', () => { const onChange = vi.fn() renderComponent({ onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -939,6 +917,7 @@ describe('DocumentPicker', () => { mockDocumentListData = { data: docs } renderComponent() + openPopover() // Documents should be rendered in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -978,14 +957,14 @@ describe('DocumentPicker', () => { // The mapping: d.data_source_detail_dict?.upload_file?.extension || '' // Should extract 'pdf' from the document - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render trigger with SearchInput integration', () => { renderComponent() // The trigger is always rendered - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should integrate FileIcon component', () => { @@ -1001,7 +980,7 @@ describe('DocumentPicker', () => { }) // FileIcon should render an SVG icon for the file extension - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -1010,9 +989,10 @@ describe('DocumentPicker', () => { describe('Visual States', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx index 7178e9f60c..c7eb2c740c 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx @@ -3,34 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import PreviewDocumentPicker from 
'../preview-document-picker' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Factory function to create mock DocumentItem const createMockDocumentItem = (overrides: Partial = {}): DocumentItem => ({ @@ -67,6 +40,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('PreviewDocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -77,7 +54,7 @@ describe('PreviewDocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name from value prop', () => { @@ -110,7 +87,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) @@ -120,7 +97,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -131,22 +108,21 @@ describe('PreviewDocumentPicker', () => { const props = createDefaultProps() render() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should apply className to trigger element', () => { renderComponent({ className: 'custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.custom-class') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('custom-class') }) it('should handle empty files array', () => { // Component should render without crashing with empty files renderComponent({ files: [] }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle single file', () => { @@ -155,7 +131,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ id: 'single-doc', name: 'Single File' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle multiple files', () => { @@ -164,7 +140,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(5), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should use value.extension for file icon', () => { @@ -172,7 +148,7 @@ describe('PreviewDocumentPicker', () => { value: createMockDocumentItem({ name: 'test.docx', extension: 'docx' }), }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -182,13 +158,13 @@ describe('PreviewDocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const 
trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -196,9 +172,10 @@ describe('PreviewDocumentPicker', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is always rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) @@ -242,7 +219,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -265,7 +242,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -274,7 +251,7 @@ describe('PreviewDocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -283,6 +260,7 @@ describe('PreviewDocumentPicker', () => { it('should render document list with files', () => { const files = createMockDocumentList(3) renderComponent({ files }) + openPopover() // Documents should be visible in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -295,6 +273,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -306,7 +285,7 @@ describe('PreviewDocumentPicker', () => { it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -337,14 +316,14 @@ describe('PreviewDocumentPicker', () => { // Renders placeholder for missing name expect(screen.getByText('--')).toBeInTheDocument() // Portal wrapper renders - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle empty files array', () => { renderComponent({ files: [] }) // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle very long document names', () => { @@ -374,7 +353,7 @@ describe('PreviewDocumentPicker', () => { render() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle large number of files', () => { @@ -382,7 +361,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files: manyFiles }) // Component should accept large files array - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle files with same name but different extensions', () => { @@ -393,7 +372,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files }) // Component should handle duplicate names - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -427,7 
+406,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ name: 'Single' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle two files', () => { @@ -435,7 +414,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(2), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle many files', () => { @@ -443,7 +422,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(50), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -451,23 +430,22 @@ describe('PreviewDocumentPicker', () => { it('should apply custom className', () => { renderComponent({ className: 'my-custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - expect(trigger.querySelector('.my-custom-class')).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('my-custom-class') }) it('should work without className', () => { renderComponent({ className: undefined }) - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should handle multiple class names', () => { renderComponent({ className: 'class-one class-two' }) - const trigger = screen.getByTestId('portal-trigger') - const element = trigger.querySelector('.class-one') - expect(element).toBeInTheDocument() - expect(element).toHaveClass('class-two') + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('class-one') + expect(trigger).toHaveClass('class-two') }) }) @@ -480,7 +458,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -491,6 +469,7 @@ describe('PreviewDocumentPicker', () => { it('should render all documents in the list', () => { const files = createMockDocumentList(5) renderComponent({ files }) + openPopover() // All documents should be visible files.forEach((file) => { @@ -503,6 +482,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -528,6 +508,7 @@ describe('PreviewDocumentPicker', () => { onChange={vi.fn()} />, ) + openPopover() expect(screen.getByText(/dataset\.preprocessDocument/)).toBeInTheDocument() }) }) @@ -537,9 +518,8 @@ describe('PreviewDocumentPicker', () => { it('should apply hover styles on trigger', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.hover\\:bg-state-base-hover') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('hover:bg-state-base-hover') }) it('should have truncate class for long names', () => { @@ -568,6 +548,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -582,10 +563,12 @@ describe('PreviewDocumentPicker', () => { ] renderComponent({ files: customFiles, 
onChange }) + openPopover() fireEvent.click(screen.getByText('Custom File 1')) expect(onChange).toHaveBeenCalledWith(customFiles[0]) + openPopover() fireEvent.click(screen.getByText('Custom File 2')) expect(onChange).toHaveBeenCalledWith(customFiles[1]) }) @@ -597,8 +580,11 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files, onChange }) // Select multiple documents sequentially + openPopover() fireEvent.click(screen.getByText('Document 1')) + openPopover() fireEvent.click(screen.getByText('Document 3')) + openPopover() fireEvent.click(screen.getByText('Document 2')) expect(onChange).toHaveBeenCalledTimes(3) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx index 574792ee14..d2d8d1966c 100644 --- a/web/app/components/datasets/common/document-picker/document-list.tsx +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback } from 'react' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' type Props = { diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx index b564b4d2b1..0566b590de 100644 --- a/web/app/components/datasets/common/document-picker/index.tsx +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -1,6 +1,12 @@ 'use client' import type { FC } from 'react' import type { DocumentItem, ParentMode, SimpleDocumentDetail } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' @@ -8,15 +14,9 @@ import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { GeneralChunk, ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' import SearchInput from '@/app/components/base/search-input' import { ChunkingMode } from '@/models/datasets' import { useDocumentList } from '@/service/knowledge/use-document' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -61,7 +61,6 @@ const DocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -77,34 +76,40 @@ const DocumentPicker: FC = ({ }, [parentMode, t]) return ( - - -
- -
-
- - {' '} - {name || '--'} - - -
-
- - - {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} - {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} - {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} - + + +
+
+ + {' '} + {name || '--'} + + +
+
+ + + {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} + {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} + {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} + +
-
- - + )} + /> +
{documentsList @@ -125,9 +130,8 @@ const DocumentPicker: FC = ({
)}
- - -
+ + ) } export default React.memo(DocumentPicker) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx index dea0693983..597ceda9a5 100644 --- a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx @@ -1,18 +1,18 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -35,7 +35,6 @@ const PreviewDocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -45,29 +44,34 @@ const PreviewDocumentPicker: FC = ({ }, [onChange, setOpen]) return ( - - -
- -
-
- - {' '} - {name || '--'} - - + + +
+
+ + {' '} + {name || '--'} + + +
-
- - + )} + /> +
- {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} + {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} {files?.length > 0 ? ( = ({
)}
- - -
+ + ) } export default React.memo(PreviewDocumentPicker) diff --git a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx index fcaca86e89..ba058860d3 100644 --- a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx +++ b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx @@ -1,6 +1,6 @@ +import { toast } from '@langgenius/dify-ui/toast' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments } from '@/service/knowledge/use-document' import AutoDisabledDocument from '../auto-disabled-document' @@ -30,7 +30,7 @@ vi.mock('@/service/knowledge/use-document', () => ({ useInvalidDisabledDocument: vi.fn(() => mockInvalidDisabledDocument), })) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: { success: mockToastSuccess, }, diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx index c6c7e03bd1..c150ac823f 100644 --- a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { toast } from '@langgenius/dify-ui/toast' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' import StatusWithAction from './status-with-action' diff --git a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx index 002a2323a7..3c2536fa5e 100644 --- a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx +++ b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { cn } from '@langgenius/dify-ui/cn' import { RiAlertFill, RiCheckboxCircleFill, RiErrorWarningFill, RiInformation2Fill } from '@remixicon/react' import * as React from 'react' import Divider from '@/app/components/base/divider' -import { cn } from '@/utils/classnames' type Status = 'success' | 'error' | 'warning' | 'info' type Props = { @@ -46,7 +46,7 @@ const StatusAction: FC = ({ }) => { const { Icon, color } = getIcon(type) return ( -
+
diff --git a/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx b/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx
--- a/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx
   it('should render without crashing', () => {
     const images = createMockImages(3)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 
   it('should render all images when count is below limit', () => {
@@ -87,7 +87,7 @@ describe('ImageList', () => {
     const images = createMockImages(15)
     render()
     // More button should be visible
-    expect(screen.getByText(/\+6/)).toBeInTheDocument()
+    expect(screen.getByText(/\+6/))!.toBeInTheDocument()
   })
 })
 
@@ -97,33 +98,33 @@ describe('ImageList', () => {
     const { container } = render(
       ,
     )
-    expect(container.firstChild).toHaveClass('custom-class')
+    expect(container.firstChild)!.toHaveClass('custom-class')
   })
 
   it('should use default limit of 9', () => {
     const images = createMockImages(12)
     render()
     // Should show "+3" for remaining images
-    expect(screen.getByText(/\+3/)).toBeInTheDocument()
+    expect(screen.getByText(/\+3/))!.toBeInTheDocument()
   })
 
   it('should respect custom limit', () => {
     const images = createMockImages(10)
     render()
     // Should show "+5" for remaining images
-    expect(screen.getByText(/\+5/)).toBeInTheDocument()
+    expect(screen.getByText(/\+5/))!.toBeInTheDocument()
   })
 
   it('should handle size prop sm', () => {
     const images = createMockImages(2)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 
   it('should handle size prop md', () => {
     const images = createMockImages(2)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 })
 
@@ -135,6 +138,6 @@ describe('ImageList', () => {
     const moreButton = screen.getByText(/\+6/)
     fireEvent.click(moreButton)
 
     // More button should disappear
     expect(screen.queryByText(/\+6/)).not.toBeInTheDocument()
   })
 
@@ -146,9 +180,9 @@ describe('ImageList', () => {
     // Find and click an image thumbnail
     const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
     if (thumbnails.length > 0) {
-      fireEvent.click(thumbnails[0])
+      fireEvent.click(thumbnails[0]!)
 
       // Preview should open
-      expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+      expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
     }
   })
 
@@ -159,12 +194,12 @@ describe('ImageList', () => {
     // Open preview
     const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
     if (thumbnails.length > 0) {
-      fireEvent.click(thumbnails[0])
+      fireEvent.click(thumbnails[0]!)
 
       // Close preview
       const closeButton = screen.getByTestId('close-preview')
       fireEvent.click(closeButton)
 
       // Preview should be closed
       expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
     }
 
@@ -174,7 +240,7 @@ describe('ImageList', () => {
 describe('Edge Cases', () => {
   it('should handle empty images array', () => {
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 
   it('should not open preview when clicked image not found in list (index === -1)', () => {
@@ -185,7 +251,7 @@
     fireEvent.click(firstThumb)
 
     // Preview should open for valid image
-    expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+    expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
 
     // Close preview
     fireEvent.click(screen.getByTestId('close-preview'))
 
@@ -197,7 +264,7 @@
     const validThumb = screen.getByTestId('file-thumb-https://example.com/image-1.png')
     fireEvent.click(validThumb)
 
-    expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+    expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
   })
 
   it('should return early when file sourceUrl is not found in limitedImages (index === -1)', () => {
@@ -216,6 +283,6 @@
     })
   }
 
   // Preview should NOT open because the file was not found in limitedImages
   expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
 })
 
@@ -223,7 +321,7 @@
   it('should handle single image', () => {
     const images = createMockImages(1)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 
   it('should not show More button when images count equals limit', () => {
@@ -236,13 +334,13 @@
     const images = createMockImages(5)
     render()
     // Should show "+5" for all images
-    expect(screen.getByText(/\+5/)).toBeInTheDocument()
+    expect(screen.getByText(/\+5/))!.toBeInTheDocument()
   })
 
   it('should handle limit larger than images count', () => {
     const images = createMockImages(5)
     render()
     // Should not show More button
     expect(screen.queryByText(/\+/)).not.toBeInTheDocument()
   })
}) diff --git a/web/app/components/datasets/common/image-list/index.tsx b/web/app/components/datasets/common/image-list/index.tsx index 48a990f94d..b04dd5afbd 100644 --- a/web/app/components/datasets/common/image-list/index.tsx +++ b/web/app/components/datasets/common/image-list/index.tsx @@ -1,8 +1,8 @@ import type { ImageInfo } from '../image-previewer' import type { FileEntity } from '@/app/components/base/file-thumb' +import { cn } from '@langgenius/dify-ui/cn' import { useCallback, useMemo, useState } from 'react' import FileThumb from '@/app/components/base/file-thumb' -import { cn } from '@/utils/classnames' import ImagePreviewer from '../image-previewer' import More from './more' diff --git a/web/app/components/datasets/common/image-list/more.tsx b/web/app/components/datasets/common/image-list/more.tsx index be6b53a5a5..255e5e4d87 100644 --- a/web/app/components/datasets/common/image-list/more.tsx +++ b/web/app/components/datasets/common/image-list/more.tsx @@ -32,7 +32,7 @@ const More = ({ count, onClick }: MoreProps) => {
-
+
 )
}
diff --git a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
index 44f0e95c08..1ff5cb33f8 100644
--- a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
@@ -65,7 +65,7 @@ describe('ImagePreviewer', () => {
     })
 
     // Should render in portal
-    expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
+    expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument()
   })
 
   it('should render close button', async () => {
@@ -77,7 +78,7 @@
     })
 
     // Esc text should be visible
-    expect(screen.getByText('Esc')).toBeInTheDocument()
+    expect(screen.getByText('Esc'))!.toBeInTheDocument()
   })
 
   it('should show loading state initially', async () => {
@@ -92,7 +94,7 @@
     })
 
     // Loading component should be visible
-    expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
+    expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument()
   })
 })
 
@@ -107,7 +110,7 @@
     await waitFor(() => {
       // Should start at second image
-      expect(screen.getByText('image2.png')).toBeInTheDocument()
+      expect(screen.getByText('image2.png'))!.toBeInTheDocument()
     })
   })
 
@@ -120,7 +124,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   })
 })
 
@@ -151,7 +155,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
 
     // Find and click next button (right arrow)
@@ -166,7 +170,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image2.png')).toBeInTheDocument()
+      expect(screen.getByText('image2.png'))!.toBeInTheDocument()
     })
   }
 })
 
@@ -180,7 +184,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image2.png')).toBeInTheDocument()
+      expect(screen.getByText('image2.png'))!.toBeInTheDocument()
     })
 
     // Find and click prev button (left arrow)
@@ -195,7 +199,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   }
 })
 
@@ -213,7 +217,7 @@
       btn.className.includes('left-8'),
     )
 
-    expect(prevButton).toBeDisabled()
+    expect(prevButton)!.toBeDisabled()
   })
 
   it('should disable next button at last image', async () => {
@@ -229,7 +233,7 @@
       btn.className.includes('right-8'),
     )
 
-    expect(nextButton).toBeDisabled()
+    expect(nextButton)!.toBeDisabled()
   })
 })
 
@@ -258,7 +262,7 @@
     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })
   })
 
@@ -275,7 +279,7 @@
     await waitFor(() => {
       // Retry button should be visible
       const retryButton = document.querySelector('button.rounded-full')
-      expect(retryButton).toBeInTheDocument()
+      expect(retryButton)!.toBeInTheDocument()
     })
   })
 })
 
@@ -290,7 +294,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
 
     const buttons = document.querySelectorAll('button')
@@ -306,7 +310,7 @@
     // Should still be at first image
     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   }
 })
 
@@ -320,7 +324,7 @@
     await waitFor(() => {
-      expect(screen.getByText('image3.png')).toBeInTheDocument()
+      expect(screen.getByText('image3.png'))!.toBeInTheDocument()
     })
 
     const buttons = document.querySelectorAll('button')
@@ -336,7 +340,7 @@
     // Should still be at last image
     await waitFor(() => {
-      expect(screen.getByText('image3.png')).toBeInTheDocument()
+      expect(screen.getByText('image3.png'))!.toBeInTheDocument()
     })
   }
 })
 
@@ -366,7 +370,7 @@
     // Wait for error state
     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })
 
     const retryButton = document.querySelector('button.rounded-full')
@@ -393,7 +397,7 @@
     })
 
     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })
 
     // Find and click the retry button (not the nav buttons)
@@ -402,7 +406,7 @@
       btn.className.includes('rounded-full') && !btn.className.includes('left-8') && !btn.className.includes('right-8'),
     )
-    expect(retryButton).toBeInTheDocument()
+    expect(retryButton)!.toBeInTheDocument()
 
     if (retryButton) {
       mockFetch.mockClear()
@@ -451,7 +455,7 @@ describe('Edge Cases', () => {
   it('should handle single image', async () => {
     const onClose = vi.fn()
-    const images = [createMockImages()[0]]
+    const images = [createMockImages()[0]!]
 
     await act(async () => {
       render()
     })
 
@@ -466,8 +470,8 @@
       btn.className.includes('right-8'),
     )
 
-    expect(prevButton).toBeDisabled()
-    expect(nextButton).toBeDisabled()
+    expect(prevButton)!.toBeDisabled()
+    expect(nextButton)!.toBeDisabled()
   })
 
   it('should stop event propagation on container click', async () => {
@@ -500,7 +504,7 @@
     await waitFor(() => {
       // Should display dimensions (800 × 600 from MockImage)
-      expect(screen.getByText(/800.*600/)).toBeInTheDocument()
+      expect(screen.getByText(/800.*600/))!.toBeInTheDocument()
     })
   })
 
@@ -514,7 +519,7 @@
     await waitFor(() => {
       // Should display formatted file size
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   })
 })
diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx
index 1164b4023a..42b4ce33a9 100644
--- a/web/app/components/datasets/common/image-previewer/index.tsx
+++ b/web/app/components/datasets/common/image-previewer/index.tsx
@@ -1,8 +1,8 @@
+import { Button } from '@langgenius/dify-ui/button'
 import { RiArrowLeftLine, RiArrowRightLine, RiCloseLine, RiRefreshLine } from '@remixicon/react'
 import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
 import { createPortal } from 'react-dom'
 import { useHotkeys } from 'react-hotkeys-hook'
-import Button from '@/app/components/base/button'
 import Loading from '@/app/components/base/loading'
 import { formatFileSize } from '@/utils/format'
 
@@ -137,7 +137,7 @@
       return {
         ...prev,
         [image.url]: {
-          ...prev[image.url],
+          ...prev[image.url]!,
           status: 'loading',
         },
       }
@@ -155,11 +155,11 @@
       onClick={e => e.stopPropagation()}
       tabIndex={-1}
     >
-
+
- {cachedImages[currentImage.url].status === 'loading' && ( + {cachedImages[currentImage!.url]!.status === 'loading' && ( )} - {cachedImages[currentImage.url].status === 'error' && ( -
- {`Failed to load image: ${currentImage.url}. Please try again.`} + {cachedImages[currentImage!.url]!.status === 'error' && ( +
+ {`Failed to load image: ${currentImage!.url}. Please try again.`}
)} - {cachedImages[currentImage.url].status === 'loaded' && ( + {cachedImages[currentImage!.url]!.status === 'loaded' && (
{currentImage.name} -
- {currentImage.name} +
+            {currentImage!.name}
             ·
-            {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`}
+            {`${cachedImages[currentImage!.url]!.width} ×  ${cachedImages[currentImage!.url]!.height}`}
             ·
-            {formatFileSize(currentImage.size)}
+            {formatFileSize(currentImage!.size)}
)} @@ -173,7 +173,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('rerank-switch')).toBeInTheDocument() + expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument() }) it('should render model selector when reranking is enabled', () => { @@ -186,7 +186,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() }) it('should not render model selector when reranking is disabled', () => { @@ -212,8 +212,8 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() - expect(screen.getByTestId('top-k-item')).toHaveAttribute('data-value', '5') + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toHaveAttribute('data-value', '5') }) it('should render score threshold item when reranking is enabled', () => { @@ -226,7 +226,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) it('should toggle reranking enable', () => { @@ -349,7 +349,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument() + expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument() }) it('should not show multimodal tip when showMultiModalTip is false', () => { @@ -378,7 +378,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('rerank-switch')).toBeInTheDocument() + expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument() }) it('should hide score threshold when reranking is disabled for full text search', () => { @@ -410,7 +410,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) }) @@ -451,7 +451,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() }) it('should not render score threshold for keyword search', () => { @@ -496,7 +496,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('dataset.weightedScore.title')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.title'))!.toBeInTheDocument() }) it('should have RerankingModel option', () => { @@ -508,7 +508,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() }) it('should show model selector when RerankingModel mode is selected', () => { @@ -520,7 +520,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() }) it('should show WeightedScore component when WeightedScore mode is selected', () => { @@ -547,7 +547,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('weighted-score')).toBeInTheDocument() + expect(screen.getByTestId('weighted-score'))!.toBeInTheDocument() expect(screen.queryByTestId('model-selector')).not.toBeInTheDocument() }) @@ -565,7 
+565,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.reranking_mode).toBe(RerankingModeEnum.WeightedScore) expect(calledWith.weights).toBeDefined() }) @@ -645,7 +645,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(screen.getByTestId('change-weights-btn')) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights.vector_setting.vector_weight).toBe(0.6) expect(calledWith.weights.keyword_setting.keyword_weight).toBe(0.4) }) @@ -659,8 +659,8 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) it('should update top_k for hybrid search', () => { @@ -724,7 +724,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument() + expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument() }) it('should not show multimodal tip for hybrid search with WeightedScore', () => { @@ -799,7 +799,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByTestId('tooltip'))!.toBeInTheDocument() }) }) @@ -814,7 +814,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() }) }) @@ -838,7 +838,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights).toBeDefined() expect(calledWith.weights.weight_type).toBe(WeightedScoreEnum.Customized) }) @@ -872,7 +872,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) 
expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights.vector_setting.vector_weight).toBe(0.8) }) }) diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index acb77acfa6..8514e1fae2 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -1,8 +1,11 @@ 'use client' import type { FC } from 'react' import type { RetrievalConfig } from '@/types/app' -import * as React from 'react' +import { cn } from '@langgenius/dify-ui/cn' +import { Switch } from '@langgenius/dify-ui/switch' +import { toast } from '@langgenius/dify-ui/toast' +import * as React from 'react' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' @@ -10,9 +13,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/aler import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import TopKItem from '@/app/components/base/param-item/top-k-item' import RadioCard from '@/app/components/base/radio-card' -import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' -import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' @@ -22,7 +23,6 @@ import { WeightedScoreEnum, } from '@/models/datasets' import { RETRIEVE_METHOD } from '@/types/app' -import { cn } from '@/utils/classnames' import ProgressIndicator from '../../create/assets/progress-indicator.svg' import Reranking from '../../create/assets/rerank.svg' @@ -121,12 +121,12 @@ const RetrievalParamConfig: FC = ({ {canToggleRerankModalEnable && ( )}
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx
index acb77acfa6..8514e1fae2 100644
--- a/web/app/components/datasets/common/retrieval-param-config/index.tsx
+++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx
@@ -1,8 +1,11 @@
 'use client'
 import type { FC } from 'react'
 import type { RetrievalConfig } from '@/types/app'
-import * as React from 'react'
+import { cn } from '@langgenius/dify-ui/cn'
+import { Switch } from '@langgenius/dify-ui/switch'
+import { toast } from '@langgenius/dify-ui/toast'
+import * as React from 'react'
 import { useCallback, useMemo } from 'react'
 import { useTranslation } from 'react-i18next'
 import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
@@ -10,9 +13,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/aler
 import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item'
 import TopKItem from '@/app/components/base/param-item/top-k-item'
 import RadioCard from '@/app/components/base/radio-card'
-import Switch from '@/app/components/base/switch'
 import Tooltip from '@/app/components/base/tooltip'
-import { toast } from '@/app/components/base/ui/toast'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@@ -22,7 +23,6 @@ import {
   WeightedScoreEnum,
 } from '@/models/datasets'
 import { RETRIEVE_METHOD } from '@/types/app'
-import { cn } from '@/utils/classnames'
 import ProgressIndicator from '../../create/assets/progress-indicator.svg'
 import Reranking from '../../create/assets/rerank.svg'
@@ -121,12 +121,12 @@ const RetrievalParamConfig: FC = ({
         {canToggleRerankModalEnable && (
         )}
-            {t('modelProvider.rerankModel.key', { ns: 'common' })}
+            {t('modelProvider.rerankModel.key', { ns: 'common' })}
             {t('modelProvider.rerankModel.tip', { ns: 'common' })}
@@ -152,11 +152,11 @@ const RetrievalParamConfig: FC = ({
           />
           {showMultiModalTip && (
-
+
-
+              {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
@@ -246,11 +246,11 @@ const RetrievalParamConfig: FC = ({
                   ...value.weights!,
                   vector_setting: {
                     ...value.weights!.vector_setting,
-                    vector_weight: v.value[0],
+                    vector_weight: v.value[0]!,
                   },
                   keyword_setting: {
                     ...value.weights!.keyword_setting,
-                    keyword_weight: v.value[1],
+                    keyword_weight: v.value[1]!,
                   },
                 },
               })
@@ -276,11 +276,11 @@ const RetrievalParamConfig: FC = ({
           />
           {showMultiModalTip && (
-
+
-
+              {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
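The component diff above swaps its legacy base imports for `@langgenius/dify-ui` entry points (`cn`, `Switch`, `toast`). A rough sketch of the resulting shape; the `Switch` prop names (`checked`, `onCheckedChange`) and the `toast.success` call are assumptions, since the JSX in these hunks did not survive extraction intact:

```tsx
import { cn } from '@langgenius/dify-ui/cn'
import { Switch } from '@langgenius/dify-ui/switch'
import { toast } from '@langgenius/dify-ui/toast'

// Hypothetical toggle: the real RetrievalParamConfig wires this into its
// reranking state; prop names and the toast API shape are assumptions.
const RerankToggle = ({ enabled, onToggle }: { enabled: boolean, onToggle: (v: boolean) => void }) => (
  <Switch
    className={cn('shrink-0', !enabled && 'opacity-50')}
    checked={enabled}
    onCheckedChange={(v) => {
      onToggle(v)
      toast.success('Rerank setting updated') // assumed method name
    }}
  />
)
```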
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
index 3a0d5f6d63..b5c73f3422 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
@@ -52,7 +52,7 @@ const toastMocks = vi.hoisted(() => {
   }
 })
 
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: toastMocks.api,
 }))
 
@@ -1144,13 +1144,13 @@ describe('CreateFromDSLModal', () => {
 
     // Error modal should be visible
     await waitFor(() => {
       expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
     })
 
     // There are two Cancel buttons now (one in main modal footer, one in error modal)
     // Find the Cancel button in the error modal context
     const cancelButtons = screen.getAllByText('app.newApp.Cancel')
-    fireEvent.click(cancelButtons[cancelButtons.length - 1])
+    fireEvent.click(cancelButtons[cancelButtons.length - 1]!)
 
     vi.useRealTimers()
   })
@@ -1223,7 +1223,7 @@ describe('Tab', () => {
     fireEvent.click(screen.getByText('app.importFromDSLFile'))
     // Tab uses bind() which passes the key as first argument and event as second
     expect(setCurrentTab).toHaveBeenCalled()
-    expect(setCurrentTab.mock.calls[0][0]).toBe(CreateFromDSLModalTab.FROM_FILE)
+    expect(setCurrentTab.mock.calls[0]![0]).toBe(CreateFromDSLModalTab.FROM_FILE)
   })
 
   it('should call setCurrentTab when clicking URL tab', () => {
@@ -1238,7 +1238,7 @@ describe('Tab', () => {
     fireEvent.click(screen.getByText('app.importFromDSLUrl'))
     // Tab uses bind() which passes the key as first argument and event as second
     expect(setCurrentTab).toHaveBeenCalled()
-    expect(setCurrentTab.mock.calls[0][0]).toBe(CreateFromDSLModalTab.FROM_URL)
+    expect(setCurrentTab.mock.calls[0]![0]).toBe(CreateFromDSLModalTab.FROM_URL)
   })
 })
@@ -1512,8 +1512,8 @@ describe('Uploader', () => {
     })
 
     // The dragRef div appears when dragging is true
-    const dragRefDiv = document.querySelector('[class*="absolute left-0 top-0"]')
+    const dragRefDiv = document.querySelector('[class*="absolute top-0 left-0"]')
     expect(dragRefDiv).toBeInTheDocument()
 
     // When dragLeave happens on the dragRef element, setDragging(false) is called
     if (dragRefDiv) {
@@ -1688,7 +1688,7 @@ describe('Uploader', () => {
     })
 
     // Now the dragRef div should exist
-    const dragRefDiv = document.querySelector('[class*="absolute left-0 top-0"]')
+    const dragRefDiv = document.querySelector('[class*="absolute top-0 left-0"]')
 
     // When dragEnter happens on dragRef itself, setDragging should NOT be called
     if (dragRefDiv) {
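The only dify-ui module these specs replace is `toast`; both this file and the uploader spec below hoist the mock so the `vi.mock` factory can reach it. The shape is roughly as follows (the `success`/`error` method names are assumptions; the real specs build the object earlier in the file as `toastMocks.api`):

```ts
import { vi } from 'vitest'

// vi.mock factories are hoisted above imports, so shared state must be
// created through vi.hoisted to be visible inside the factory.
const toastMocks = vi.hoisted(() => {
  // Method names here are assumptions about the toast API surface.
  const api = { success: vi.fn(), error: vi.fn() }
  return { api }
})

vi.mock('@langgenius/dify-ui/toast', () => ({
  toast: toastMocks.api,
}))
```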
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
index f553f491e5..f4263ab439 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
@@ -16,7 +16,7 @@ const { mockToast } = vi.hoisted(() => {
   return { mockToast }
 })
 
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: mockToast,
 }))
 
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
index db675228c4..d0c97e185c 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
@@ -1,5 +1,5 @@
+import { Button } from '@langgenius/dify-ui/button'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Modal from '@/app/components/base/modal'
 
 type DSLConfirmModalProps = {
@@ -27,7 +27,7 @@ const DSLConfirmModal = ({
       >
         {t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}
-
+
           {t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}
           {t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}
 
@@ -43,7 +43,7 @@ const DSLConfirmModal = ({
-
+
     )
 }
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx
index 5e4cb88c27..cf759880be 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx
@@ -12,10 +12,10 @@ const Header = ({
   const { t } = useTranslation()
 
   return (
-
+
       {t('importFromDSL', { ns: 'app' })}
 
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
index 2a0f7f2cf0..b95fa4d5db 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
@@ -49,7 +49,7 @@ const toastMocks = vi.hoisted(() => {
   }
 })
 
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: toastMocks.api,
 }))
 
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
index 992d0526be..e085ffe1bc 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
@@ -1,8 +1,8 @@
 'use client'
+import { toast } from '@langgenius/dify-ui/toast'
 import { useDebounceFn } from 'ahooks'
 import { useCallback, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { toast } from '@/app/components/base/ui/toast'
 import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
 import { DSLImportMode, DSLImportStatus } from '@/models/app'
 import { useRouter } from '@/next/navigation'
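In hook code the move is the same: `toast` now comes from the package entry point and sorts with the other external imports, ahead of the `@/` aliases. A hedged sketch of a consumer in the style of `use-dsl-import` (the translation key and the `toast.error` signature are illustrative, not taken from the hook):

```ts
import { toast } from '@langgenius/dify-ui/toast'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'

// Illustrative only: a failure handler resembling what use-dsl-import does;
// the key and the toast.error call shape are assumptions.
export const useImportErrorToast = () => {
  const { t } = useTranslation()
  return useCallback((message?: string) => {
    toast.error(message ?? t('app.newApp.appCreateDSLErrorTitle'))
  }, [t])
}
```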
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
index 079ea90687..4f5b5c23fd 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
@@ -1,8 +1,8 @@
 'use client'
+import { Button } from '@langgenius/dify-ui/button'
 import { useKeyPress } from 'ahooks'
 import { noop } from 'es-toolkit/function'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Modal from '@/app/components/base/modal'
 import DSLConfirmModal from './dsl-confirm-modal'
@@ -78,7 +78,7 @@ const CreateFromDSLModal = ({
           )}
           {currentTab === CreateFromDSLModalTab.FROM_URL && (
-
+
               DSL URL
+
           { tabs.map(tab => (
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx
= ({ file, updateFile, className }) => {
-        {dragging &&
-        }
+        {dragging &&
+        }
 
       )}
       {file && (
@@ -104,10 +104,10 @@ const Uploader: FC = ({ file, updateFile, className }) => {
-
+
             {file.name}
-
+
             PIPELINE · {formatFileSize(file.size)}
diff --git a/web/app/components/datasets/create-from-pipeline/footer.tsx b/web/app/components/datasets/create-from-pipeline/footer.tsx
index ae1bb48394..dee48163d2 100644
--- a/web/app/components/datasets/create-from-pipeline/footer.tsx
+++ b/web/app/components/datasets/create-from-pipeline/footer.tsx
@@ -39,11 +39,11 @@ const Footer = () => {
   }, [invalidDatasetList])
 
   return (
-
+
diff --git a/web/app/components/datasets/create-from-pipeline/index.tsx b/web/app/components/datasets/create-from-pipeline/index.tsx
index 0c788f613d..f5b5c2ec8b 100644
--- a/web/app/components/datasets/create-from-pipeline/index.tsx
+++ b/web/app/components/datasets/create-from-pipeline/index.tsx
@@ -9,7 +9,7 @@ const CreateFromPipeline = () => {
-
+