merge evaluation-fe

commit 47050b8d15
.github/workflows/api-tests.yml (vendored, 6 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   api-unit:
     name: API Unit Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       COVERAGE_FILE: coverage-unit
     defaults:
@@ -62,7 +62,7 @@ jobs:

   api-integration:
     name: API Integration Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       COVERAGE_FILE: coverage-integration
       STORAGE_TYPE: opendal
@@ -137,7 +137,7 @@ jobs:

   api-coverage:
     name: API Coverage
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     needs:
       - api-unit
       - api-integration
.github/workflows/autofix.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ permissions:
 jobs:
   autofix:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
.github/workflows/build-push.yml (vendored, 46 changes)

@@ -26,6 +26,9 @@ jobs:
   build:
     runs-on: ${{ matrix.runs_on }}
     if: github.repository == 'langgenius/dify'
+    permissions:
+      contents: read
+      id-token: write
     strategy:
       matrix:
         include:
@@ -35,28 +38,28 @@ jobs:
           build_context: "{{defaultContext}}:api"
           file: "Dockerfile"
           platform: linux/amd64
-          runs_on: ubuntu-latest
+          runs_on: depot-ubuntu-24.04-4
         - service_name: "build-api-arm64"
           image_name_env: "DIFY_API_IMAGE_NAME"
           artifact_context: "api"
           build_context: "{{defaultContext}}:api"
           file: "Dockerfile"
           platform: linux/arm64
-          runs_on: ubuntu-24.04-arm
+          runs_on: depot-ubuntu-24.04-4
         - service_name: "build-web-amd64"
          image_name_env: "DIFY_WEB_IMAGE_NAME"
          artifact_context: "web"
          build_context: "{{defaultContext}}"
          file: "web/Dockerfile"
          platform: linux/amd64
-          runs_on: ubuntu-latest
+          runs_on: depot-ubuntu-24.04-4
        - service_name: "build-web-arm64"
          image_name_env: "DIFY_WEB_IMAGE_NAME"
          artifact_context: "web"
          build_context: "{{defaultContext}}"
          file: "web/Dockerfile"
          platform: linux/arm64
-          runs_on: ubuntu-24.04-arm
+          runs_on: depot-ubuntu-24.04-4

     steps:
       - name: Prepare
@@ -70,8 +73,8 @@ jobs:
           username: ${{ env.DOCKERHUB_USER }}
           password: ${{ env.DOCKERHUB_TOKEN }}

-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
+      - name: Set up Depot CLI
+        uses: depot/setup-action@v1

       - name: Extract metadata for Docker
         id: meta
@@ -81,16 +84,15 @@ jobs:

       - name: Build Docker image
         id: build
-        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
+        uses: depot/build-push-action@v1
         with:
+          project: ${{ vars.DEPOT_PROJECT_ID }}
           context: ${{ matrix.build_context }}
           file: ${{ matrix.file }}
           platforms: ${{ matrix.platform }}
           build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
           labels: ${{ steps.meta.outputs.labels }}
           outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
-          cache-from: type=gha,scope=${{ matrix.service_name }}
-          cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}

       - name: Export digest
         env:
@@ -108,9 +110,33 @@ jobs:
           if-no-files-found: error
           retention-days: 1

+  fork-build-validate:
+    if: github.repository != 'langgenius/dify'
+    runs-on: ubuntu-24.04
+    strategy:
+      matrix:
+        include:
+          - service_name: "validate-api-amd64"
+            build_context: "{{defaultContext}}:api"
+            file: "Dockerfile"
+          - service_name: "validate-web-amd64"
+            build_context: "{{defaultContext}}"
+            file: "web/Dockerfile"
+    steps:
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
+
+      - name: Validate Docker image
+        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
+        with:
+          push: false
+          context: ${{ matrix.build_context }}
+          file: ${{ matrix.file }}
+          platforms: linux/amd64
+
   create-manifest:
     needs: build
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: github.repository == 'langgenius/dify'
     strategy:
       matrix:
.github/workflows/db-migration-test.yml (vendored, 4 changes)

@@ -9,7 +9,7 @@ concurrency:

 jobs:
   db-migration-test-postgres:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
@@ -59,7 +59,7 @@ jobs:
         run: uv run --directory api flask upgrade-db

   db-migration-test-mysql:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
.github/workflows/deploy-agent-dev.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/agent-dev'

.github/workflows/deploy-dev.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/dev'

.github/workflows/deploy-enterprise.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/enterprise'

.github/workflows/deploy-hitl.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'build/feat/hitl'
.github/workflows/docker-build.yml (vendored, 47 changes)

@@ -14,40 +14,69 @@ concurrency:

 jobs:
   build-docker:
     if: github.event.pull_request.head.repo.full_name == github.repository
     runs-on: ${{ matrix.runs_on }}
+    permissions:
+      contents: read
+      id-token: write
     strategy:
       matrix:
         include:
           - service_name: "api-amd64"
             platform: linux/amd64
-            runs_on: ubuntu-latest
+            runs_on: depot-ubuntu-24.04-4
             context: "{{defaultContext}}:api"
             file: "Dockerfile"
           - service_name: "api-arm64"
             platform: linux/arm64
-            runs_on: ubuntu-24.04-arm
+            runs_on: depot-ubuntu-24.04-4
             context: "{{defaultContext}}:api"
             file: "Dockerfile"
           - service_name: "web-amd64"
             platform: linux/amd64
-            runs_on: ubuntu-latest
+            runs_on: depot-ubuntu-24.04-4
             context: "{{defaultContext}}"
             file: "web/Dockerfile"
           - service_name: "web-arm64"
             platform: linux/arm64
-            runs_on: ubuntu-24.04-arm
+            runs_on: depot-ubuntu-24.04-4
             context: "{{defaultContext}}"
             file: "web/Dockerfile"
     steps:
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
+      - name: Set up Depot CLI
+        uses: depot/setup-action@v1

       - name: Build Docker Image
-        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
+        uses: depot/build-push-action@v1
         with:
+          project: ${{ vars.DEPOT_PROJECT_ID }}
           push: false
           context: ${{ matrix.context }}
           file: ${{ matrix.file }}
           platforms: ${{ matrix.platform }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+
+  build-docker-fork:
+    if: github.event.pull_request.head.repo.full_name != github.repository
+    runs-on: ubuntu-24.04
+    permissions:
+      contents: read
+    strategy:
+      matrix:
+        include:
+          - service_name: "api-amd64"
+            context: "{{defaultContext}}:api"
+            file: "Dockerfile"
+          - service_name: "web-amd64"
+            context: "{{defaultContext}}"
+            file: "web/Dockerfile"
+    steps:
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
+
+      - name: Build Docker Image
+        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
+        with:
+          push: false
+          context: ${{ matrix.context }}
+          file: ${{ matrix.file }}
+          platforms: linux/amd64
.github/workflows/labeler.yml (vendored, 2 changes)

@@ -7,7 +7,7 @@ jobs:
     permissions:
       contents: read
       pull-requests: write
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
         with:
.github/workflows/main-ci.yml (vendored, 24 changes)

@@ -23,7 +23,7 @@ concurrency:
 jobs:
   pre_job:
     name: Skip Duplicate Checks
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     outputs:
       should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
     steps:
@@ -39,7 +39,7 @@ jobs:
     name: Check Changed Files
     needs: pre_job
     if: needs.pre_job.outputs.should_skip != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     outputs:
       api-changed: ${{ steps.changes.outputs.api }}
       e2e-changed: ${{ steps.changes.outputs.e2e }}
@@ -141,7 +141,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped API tests
         run: echo "No API-related changes detected; skipping API tests."
@@ -154,7 +154,7 @@ jobs:
       - check-changes
       - api-tests-run
       - api-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize API Tests status
         env:
@@ -201,7 +201,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped web tests
         run: echo "No web-related changes detected; skipping web tests."
@@ -214,7 +214,7 @@ jobs:
       - check-changes
       - web-tests-run
       - web-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize Web Tests status
         env:
@@ -260,7 +260,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped web full-stack e2e
         run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@@ -273,7 +273,7 @@ jobs:
       - check-changes
       - web-e2e-run
       - web-e2e-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize Web Full-Stack E2E status
         env:
@@ -325,7 +325,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped VDB tests
         run: echo "No VDB-related changes detected; skipping VDB tests."
@@ -338,7 +338,7 @@ jobs:
       - check-changes
       - vdb-tests-run
       - vdb-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize VDB Tests status
         env:
@@ -384,7 +384,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped DB migration tests
         run: echo "No migration-related changes detected; skipping DB migration tests."
@@ -397,7 +397,7 @@ jobs:
       - check-changes
       - db-migration-test-run
       - db-migration-test-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize DB Migration Test status
         env:
.github/workflows/pyrefly-diff-comment.yml (vendored, 2 changes)

@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with pyrefly diff
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       actions: read
       contents: read
.github/workflows/pyrefly-diff.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-diff:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       contents: read
       issues: write
@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with type coverage
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       actions: read
       contents: read
.github/workflows/pyrefly-type-coverage.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-type-coverage:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       contents: read
       issues: write
.github/workflows/semantic-pull-request.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
     name: Validate PR title
     permissions:
       pull-requests: read
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
.github/workflows/stale.yml (vendored, 2 changes)

@@ -12,7 +12,7 @@ on:
 jobs:
   stale:

-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       issues: write
       pull-requests: write
.github/workflows/style.yml (vendored, 8 changes)

@@ -15,7 +15,7 @@ permissions:
 jobs:
   python-style:
     name: Python Style
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
@@ -57,7 +57,7 @@ jobs:

   web-style:
     name: Web Style
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     defaults:
       run:
         working-directory: ./web
@@ -108,6 +108,8 @@ jobs:
       - name: Web tsslint
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
+        env:
+          NODE_OPTIONS: --max-old-space-size=4096
         run: vp run lint:tss

       - name: Web type check
@@ -129,7 +131,7 @@ jobs:

   superlinter:
     name: SuperLinter
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
.github/workflows/tool-test-sdks.yaml (vendored, 2 changes)

@@ -18,7 +18,7 @@ concurrency:
 jobs:
   build:
     name: unit test for Node.js SDK
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     defaults:
       run:
.github/workflows/translate-i18n-claude.yml (vendored, 4 changes)

@@ -35,7 +35,7 @@ concurrency:
 jobs:
   translate:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     timeout-minutes: 120

     steps:
@@ -158,7 +158,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.context.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
+        uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/trigger-i18n-sync.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   trigger:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     timeout-minutes: 5

     steps:
.github/workflows/vdb-tests-full.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
   test:
     name: Full VDB Tests
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     strategy:
       matrix:
         python-version:
.github/workflows/vdb-tests.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: VDB Smoke Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     strategy:
       matrix:
         python-version:
.github/workflows/web-e2e.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: Web Full-Stack E2E
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     defaults:
       run:
         shell: bash
.github/workflows/web-tests.yml (vendored, 6 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   test:
     name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       VITEST_COVERAGE_SCOPE: app-components
     strategy:
@@ -54,7 +54,7 @@ jobs:
     name: Merge Test Reports
     if: ${{ !cancelled() }}
     needs: [test]
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
@@ -92,7 +92,7 @@ jobs:

   dify-ui-test:
     name: dify-ui Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
@@ -147,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source

 ### Deployment with Kubernetes

-If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
+If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.

 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
 - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
@@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
     type: TagType = Field(description="Tag type")


+class TagBindingItemDeletePayload(BaseModel):
+    target_id: str = Field(description="Target ID to unbind tag from")
+    type: TagType = Field(description="Tag type")
+
+
 class TagListQueryParam(BaseModel):
     type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
     keyword: str | None = Field(None, description="Search keyword")
@@ -70,6 +75,7 @@ register_schema_models(
     TagBasePayload,
     TagBindingPayload,
     TagBindingRemovePayload,
+    TagBindingItemDeletePayload,
     TagListQueryParam,
     TagResponse,
 )
@@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
         return "", 204


-@console_ns.route("/tag-bindings/create")
-class TagBindingCreateApi(Resource):
+def _require_tag_binding_edit_permission() -> None:
+    """
+    Ensure the current account can edit tag bindings.
+
+    Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
+    """
+    current_user, _ = current_account_with_tenant()
+    # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
+
+
+def _create_tag_bindings() -> tuple[dict[str, str], int]:
+    _require_tag_binding_edit_permission()
+
+    payload = TagBindingPayload.model_validate(console_ns.payload or {})
+    TagService.save_tag_binding(
+        TagBindingCreatePayload(
+            tag_ids=payload.tag_ids,
+            target_id=payload.target_id,
+            type=payload.type,
+        )
+    )
+    return {"result": "success"}, 200
+
+
+def _remove_tag_binding() -> tuple[dict[str, str], int]:
+    _require_tag_binding_edit_permission()
+
+    payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
+    TagService.delete_tag_binding(
+        TagBindingDeletePayload(
+            tag_id=payload.tag_id,
+            target_id=payload.target_id,
+            type=payload.type,
+        )
+    )
+    return {"result": "success"}, 200
+
+
+@console_ns.route("/tag-bindings")
+class TagBindingCollectionApi(Resource):
+    """Canonical collection resource for tag binding creation."""
+
     @console_ns.doc("create_tag_binding")
     @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
+        return _create_tag_bindings()

-        payload = TagBindingPayload.model_validate(console_ns.payload or {})
-        TagService.save_tag_binding(
-            TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
-        )

-        return {"result": "success"}, 200
+@console_ns.route("/tag-bindings/<uuid:id>")
+class TagBindingItemApi(Resource):
+    """Canonical item resource for tag binding deletion."""
+
+    @console_ns.doc("delete_tag_binding")
+    @console_ns.doc(params={"id": "Tag ID"})
+    @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, id):
+        _require_tag_binding_edit_permission()
+        payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
+        TagService.delete_tag_binding(
+            TagBindingDeletePayload(
+                tag_id=str(id),
+                target_id=payload.target_id,
+                type=payload.type,
+            )
+        )
+
+        return {"result": "success"}, 200
+
+
+@console_ns.route("/tag-bindings/create")
+class DeprecatedTagBindingCreateApi(Resource):
+    """Deprecated verb-based alias for tag binding creation."""
+
+    @console_ns.doc("create_tag_binding_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
+    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        return _create_tag_bindings()


 @console_ns.route("/tag-bindings/remove")
-class TagBindingDeleteApi(Resource):
+class DeprecatedTagBindingRemoveApi(Resource):
+    """Deprecated verb-based alias for tag binding deletion."""
+
+    @console_ns.doc("delete_tag_binding_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
     @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
-
-        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
-        TagService.delete_tag_binding(
-            TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
-        )
-
-        return {"result": "success"}, 200
+        return _remove_tag_binding()
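For orientation: the canonical resources above replace the verb-style routes one-for-one. A minimal client sketch follows, assuming a console base URL, an already-authenticated requests session, and placeholder IDs — all illustrative, none of them part of the commit:

import requests

BASE = "https://dify.example.com/console/api"  # assumed console API base URL
session = requests.Session()  # assumed to already carry console auth cookies/headers

# Canonical create: POST /tag-bindings (replaces POST /tag-bindings/create)
session.post(
    f"{BASE}/tag-bindings",
    json={"tag_ids": ["<tag-uuid>"], "target_id": "<app-or-dataset-uuid>", "type": "app"},
)

# Canonical delete: DELETE /tag-bindings/<tag-uuid> (replaces POST /tag-bindings/remove);
# the tag ID moves into the path, while target_id and type stay in the body.
session.delete(
    f"{BASE}/tag-bindings/<tag-uuid>",
    json={"target_id": "<app-or-dataset-uuid>", "type": "app"},
)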
@@ -527,6 +527,7 @@ class RetrievalService:
             child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

             for i in child_index_nodes:
+                assert i.index_node_id
                 segment_ids.append(i.segment_id)
                 if i.segment_id in child_chunk_map:
                     child_chunk_map[i.segment_id].append(i)
@@ -11,6 +11,7 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
 from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
+from models.enums import SegmentType


 class DatasetDocumentStore:
@@ -127,6 +128,7 @@ class DatasetDocumentStore:
                 if save_child:
                     if doc.children:
                         for position, child in enumerate(doc.children, start=1):
+                            assert self._document_id
                             child_segment = ChildChunk(
                                 tenant_id=self._dataset.tenant_id,
                                 dataset_id=self._dataset.id,
@@ -137,7 +139,7 @@ class DatasetDocumentStore:
                                 index_node_hash=child.metadata.get("doc_hash"),
                                 content=child.page_content,
                                 word_count=len(child.page_content),
-                                type="automatic",
+                                type=SegmentType.AUTOMATIC,
                                 created_by=self._user_id,
                             )
                             db.session.add(child_segment)
@@ -163,6 +165,7 @@ class DatasetDocumentStore:
                             )
                         # add new child chunks
                         for position, child in enumerate(doc.children, start=1):
+                            assert self._document_id
                             child_segment = ChildChunk(
                                 tenant_id=self._dataset.tenant_id,
                                 dataset_id=self._dataset.id,
@@ -173,7 +176,7 @@ class DatasetDocumentStore:
                                 index_node_hash=child.metadata.get("doc_hash"),
                                 content=child.page_content,
                                 word_count=len(child.page_content),
-                                type="automatic",
+                                type=SegmentType.AUTOMATIC,
                                 created_by=self._user_id,
                             )
                             db.session.add(child_segment)
@@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
         return attachment_list


-class ChildChunk(Base):
+class ChildChunk(TypeBase):
     __tablename__ = "child_chunks"
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@@ -1046,29 +1046,42 @@ class ChildChunk(Base):
     )

-    # initial fields
-    id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
-    tenant_id = mapped_column(StringUUID, nullable=False)
-    dataset_id = mapped_column(StringUUID, nullable=False)
-    document_id = mapped_column(StringUUID, nullable=False)
-    segment_id = mapped_column(StringUUID, nullable=False)
+    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
-    content = mapped_column(LongText, nullable=False)
+    content: Mapped[str] = mapped_column(LongText, nullable=False)
     word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
-    # indexing fields
-    index_node_id = mapped_column(String(255), nullable=True)
-    index_node_hash = mapped_column(String(255), nullable=True)
-    type: Mapped[SegmentType] = mapped_column(
-        EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
-    )
-    created_by = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
-    updated_by = mapped_column(StringUUID, nullable=True)
+    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
+    )
+    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
     updated_at: Mapped[datetime] = mapped_column(
-        DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
+        DateTime,
+        nullable=False,
+        server_default=sa.func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+        init=False,
     )
-    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
-    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
-    error = mapped_column(LongText, nullable=True)
+    indexing_at: Mapped[datetime | None] = mapped_column(
+        DateTime, nullable=True, insert_default=None, server_default=None, init=False
+    )
+    completed_at: Mapped[datetime | None] = mapped_column(
+        DateTime, nullable=True, insert_default=None, server_default=None, init=False
+    )
+    index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+    index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+    type: Mapped[SegmentType] = mapped_column(
+        EnumText(SegmentType, length=255),
+        nullable=False,
+        server_default=sa.text("'automatic'"),
+        default=SegmentType.AUTOMATIC,
+    )
+    error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)

     @property
     def dataset(self):
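The ChildChunk rewrite above follows SQLAlchemy 2.0's dataclass-mapped pattern (TypeBase is evidently a MappedAsDataclass-style base): columns marked init=False, such as id, created_at, and error, drop out of __init__, while default=/default_factory= supply Python-side values. A self-contained sketch of those semantics, using illustrative names rather than Dify's:

from uuid import uuid4

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class Base(MappedAsDataclass, DeclarativeBase):
    """Stand-in for a dataclass-enabled declarative base like TypeBase."""


class Chunk(Base):
    __tablename__ = "chunks"

    # init=False + default_factory: generated at construction time, never passed by callers.
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default_factory=lambda: str(uuid4()), init=False
    )
    content: Mapped[str] = mapped_column(String(2000), nullable=False)
    # default=None: an optional keyword in __init__, like index_node_id in the diff.
    index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)


chunk = Chunk(content="hello")  # id is generated automatically
print(chunk.id, chunk.index_node_id)
# Chunk(id="x", content="hello") would raise a TypeError: __init__ accepts no 'id'.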
@@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
     )

     id: Mapped[str] = mapped_column(
-        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+        StringUUID,
+        insert_default=lambda: str(uuid4()),
+        default_factory=lambda: str(uuid4()),
+        init=False,
     )
     app_id: Mapped[str] = mapped_column(StringUUID)
     question: Mapped[str] = mapped_column(LongText, nullable=False)
     content: Mapped[str] = mapped_column(LongText, nullable=False)
-    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
     account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
     message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
+    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
     created_at: Mapped[datetime] = mapped_column(
         sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
     )
@@ -225,8 +225,10 @@ class TestSpanBuilder:
         span = builder.build_span(span_data)
         assert isinstance(span, ReadableSpan)
         assert span.name == "test-span"
+        assert span.context is not None
         assert span.context.trace_id == 123
         assert span.context.span_id == 456
+        assert span.parent is not None
         assert span.parent.span_id == 789
         assert span.resource == resource
         assert span.attributes == {"attr1": "val1"}
@@ -64,12 +64,13 @@ class TestSpanData:

     def test_span_data_missing_required_fields(self):
         with pytest.raises(ValidationError):
-            SpanData(
-                trace_id=123,
-                # span_id missing
-                name="test_span",
-                start_time=1000,
-                end_time=2000,
+            SpanData.model_validate(
+                {
+                    "trace_id": 123,
+                    "name": "test_span",
+                    "start_time": 1000,
+                    "end_time": 2000,
+                }
             )

     def test_span_data_arbitrary_types_allowed(self):
@@ -2,12 +2,14 @@ from __future__ import annotations

 from datetime import UTC, datetime
 from types import SimpleNamespace
+from typing import cast
 from unittest.mock import MagicMock

 import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
 import pytest
 from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
 from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
 from dify_trace_aliyun.entities.semconv import (
     GEN_AI_COMPLETION,
     GEN_AI_INPUT_MESSAGE,
@@ -44,7 +46,7 @@ class RecordingTraceClient:
         self.endpoint = endpoint
         self.added_spans: list[object] = []

-    def add_span(self, span) -> None:
+    def add_span(self, span: object) -> None:
         self.added_spans.append(span)

     def api_check(self) -> bool:
@@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
         trace_id=trace_id,
         span_id=span_id,
         is_remote=False,
-        trace_flags=TraceFlags.SAMPLED,
+        trace_flags=TraceFlags(TraceFlags.SAMPLED),
     )
     return Link(context)


+def _make_trace_metadata(
+    trace_id: int = 1,
+    workflow_span_id: int = 2,
+    session_id: str = "s",
+    user_id: str = "u",
+    links: list[Link] | None = None,
+) -> TraceMetadata:
+    return TraceMetadata(
+        trace_id=trace_id,
+        workflow_span_id=workflow_span_id,
+        session_id=session_id,
+        user_id=user_id,
+        links=[] if links is None else links,
+    )
+
+
+def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
+    return cast(RecordingTraceClient, trace_instance.trace_client)
+
+
+def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
+    return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
+
+
 def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
     defaults = {
         "workflow_id": "workflow-id",
@@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
     trace_instance.workflow_trace(trace_info)

     add_workflow_span.assert_called_once()
-    passed_trace_metadata = add_workflow_span.call_args.args[1]
+    passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
     assert passed_trace_metadata.trace_id == 111
     assert passed_trace_metadata.workflow_span_id == 222
     assert passed_trace_metadata.session_id == "c"
     assert passed_trace_metadata.user_id == "u"
     assert passed_trace_metadata.links == []

-    assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
+    assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]


 def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_message_trace_info(message_data=None)
     trace_instance.message_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
     )
     trace_instance.message_trace(trace_info)

-    assert len(trace_instance.trace_client.added_spans) == 2
-    message_span, llm_span = trace_instance.trace_client.added_spans
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 2
+    message_span, llm_span = spans

     assert message_span.name == "message"
     assert message_span.trace_id == 10
@@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
 def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_dataset_retrieval_trace_info(message_data=None)
     trace_instance.dataset_retrieval_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
     monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])

     trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "dataset_retrieval"
     assert span.attributes[RETRIEVAL_QUERY] == "query"
     assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
 def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_tool_trace_info(message_data=None)
     trace_instance.tool_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
         )
     )

-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "my-tool"
     assert span.status == status
     assert span.attributes[TOOL_NAME] == "my-tool"
@@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
 def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))

@@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
 ):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))

@@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
 def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))

@@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
 def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))

@@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
 ):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
     node_execution.node_type = BuiltinNodeTypes.CODE
@@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: 
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "title"
@@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: 
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
+    trace_metadata = _make_trace_metadata(links=[_make_link()])
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "my-tool"
@@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
         aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
     )

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "retrieval"
@@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
     monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
     monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "llm"
@@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()

     # CASE 1: With message_id
     trace_info = _make_workflow_trace_info(
@@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     )
     trace_instance.add_workflow_span(trace_info, trace_metadata)

-    assert len(trace_instance.trace_client.added_spans) == 2
-    message_span = trace_instance.trace_client.added_spans[0]
-    workflow_span = trace_instance.trace_client.added_spans[1]
+    client = _recording_trace_client(trace_instance)
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 2
+    message_span = spans[0]
+    workflow_span = spans[1]

     assert message_span.name == "message"
     assert message_span.span_kind == SpanKind.SERVER
@@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     assert workflow_span.span_kind == SpanKind.INTERNAL
     assert workflow_span.parent_span_id == 20

-    trace_instance.trace_client.added_spans.clear()
+    client.added_spans.clear()

     # CASE 2: Without message_id
     trace_info_no_msg = _make_workflow_trace_info(message_id=None)
     trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "workflow"
     assert span.span_kind == SpanKind.SERVER
     assert span.parent_span_id is None
@@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: 
     trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
     trace_instance.suggested_question_trace(trace_info)

-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "suggested_question"
     assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'
@@ -1,4 +1,6 @@
 import json
+from collections.abc import Mapping
+from typing import Any, cast
 from unittest.mock import MagicMock

 from dify_trace_aliyun.entities.semconv import (
@@ -170,7 +172,7 @@ def test_create_common_span_attributes():

 def test_format_retrieval_documents():
     # Not a list
-    assert format_retrieval_documents("not a list") == []
+    assert format_retrieval_documents(cast(list[object], "not a list")) == []

     # Valid list
     docs = [
@@ -211,7 +213,7 @@ def test_format_retrieval_documents():

 def test_format_input_messages():
     # Not a dict
-    assert format_input_messages(None) == serialize_json_data([])
+    assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])

     # No prompts
     assert format_input_messages({}) == serialize_json_data([])
@@ -244,7 +246,7 @@ def test_format_input_messages():

 def test_format_output_messages():
     # Not a dict
-    assert format_output_messages(None) == serialize_json_data([])
+    assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])

     # No text
     assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])
@@ -25,13 +25,13 @@ class TestAliyunConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            AliyunConfig()
+            AliyunConfig.model_validate({})

         with pytest.raises(ValidationError):
-            AliyunConfig(license_key="test_license")
+            AliyunConfig.model_validate({"license_key": "test_license"})

         with pytest.raises(ValidationError):
-            AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+            AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})

     def test_app_name_validation_empty(self):
         """Test app_name validation with empty value"""
@@ -1,4 +1,5 @@
 from datetime import UTC, datetime, timedelta
+from typing import cast
 from unittest.mock import MagicMock, patch

 import pytest
@@ -129,7 +130,7 @@ def test_set_span_status():
             return "SilentErrorRepr"

     span.reset_mock()
-    set_span_status(span, SilentError())
+    set_span_status(span, cast(Exception | str | None, SilentError()))
     assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
@@ -28,13 +28,13 @@ class TestLangfuseConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            LangfuseConfig()
+            LangfuseConfig.model_validate({})

         with pytest.raises(ValidationError):
-            LangfuseConfig(public_key="public")
+            LangfuseConfig.model_validate({"public_key": "public"})

         with pytest.raises(ValidationError):
-            LangfuseConfig(secret_key="secret")
+            LangfuseConfig.model_validate({"secret_key": "secret"})

     def test_host_validation_empty(self):
         """Test host validation with empty value"""
@@ -2,6 +2,7 @@

 from datetime import datetime, timedelta
 from types import SimpleNamespace
+from typing import cast
 from unittest.mock import MagicMock, patch

 from dify_trace_langfuse.config import LangfuseConfig
@@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:

         assert trace._get_completion_start_time(start_time, None) is None
         assert trace._get_completion_start_time(start_time, -1) is None
-        assert trace._get_completion_start_time(start_time, "invalid") is None
+        assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None
@@ -21,13 +21,13 @@ class TestLangSmithConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            LangSmithConfig()
+            LangSmithConfig.model_validate({})

         with pytest.raises(ValidationError):
-            LangSmithConfig(api_key="key")
+            LangSmithConfig.model_validate({"api_key": "key"})

         with pytest.raises(ValidationError):
-            LangSmithConfig(project="project")
+            LangSmithConfig.model_validate({"project": "project"})

     def test_endpoint_validation_https_only(self):
         """Test endpoint validation only allows HTTPS"""
@@ -599,7 +599,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_instance.message_trace(_make_message_trace_info())
         mock_tracing["start"].assert_called_once()
@@ -609,7 +608,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(error="something broke")
         trace_instance.message_trace(trace_info)
@@ -620,7 +618,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None
         monkeypatch.setenv("FILES_URL", "http://files.test")

         file_data = SimpleNamespace(url="path/to/file.png")
@@ -638,7 +635,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
         trace_instance.message_trace(trace_info)
@@ -651,7 +647,6 @@ class TestMessageTrace:

         end_user = MagicMock()
         end_user.session_id = "session-xyz"
-        mock_db.session.query.return_value.where.return_value.first.return_value = end_user

         trace_info = _make_message_trace_info(
             metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@@ -664,7 +659,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(
             metadata={"from_account_id": "acc-1"},
@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
|
||||
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
|
||||
return instance
|
||||
|
||||
|
||||
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_trace)
|
||||
|
||||
|
||||
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
|
||||
return cast(MagicMock, instance.add_span)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called

def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)

expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected

def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)

root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None

def test_trace_name_is_workflow_trace(self):
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)

trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE

def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)

root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE

def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)

root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]

def test_node_execution_spans_are_parented_to_root(self):
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])

# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

def test_node_span_not_parented_to_workflow_app_log_id(self):
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])

old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id

def test_root_span_id_differs_from_trace_id(self):
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)

trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE

def test_root_span_uses_workflow_run_id_directly(self):
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)

expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id

def test_root_span_id_differs_from_no_message_id_case(self):
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:

instance = self._run(trace_info, node_executions=[node_exec])

node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock

import pytest
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags

metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs


class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any


def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""

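The `PatchedCoreComponents` TypedDict above gives the autouse fixture a precise return type in place of `dict[str, object]`, so tests can index the dict without sprinkling casts. A small sketch of the pattern under the same assumption (pytest; the field names here are invented):

from typing import TypedDict
from unittest.mock import MagicMock

import pytest


class Components(TypedDict):
    tracer: MagicMock
    logger: MagicMock


@pytest.fixture
def components() -> Components:
    return {"tracer": MagicMock(name="tracer"), "logger": MagicMock(name="logger")}


def test_start_span_is_recorded(components: Components) -> None:
    # components["tracer"] is typed as MagicMock, so no cast is needed here.
    components["tracer"].start_span("root")
    components["tracer"].start_span.assert_called_once_with("root")
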
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:


@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))

@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}


def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)


def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
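
One plausible reading of `_make_span_context`: earlier revisions stored placeholder strings such as "ctx" in `client.span_contexts`, which breaks as soon as anything hands the stored value back to the OpenTelemetry SDK. Building a real `SpanContext` keeps the fixture type-correct end to end. A runnable sketch of the construction (ids are arbitrary):

from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags

ctx = SpanContext(
    trace_id=0x1,  # 128-bit trace id, truncated to 1 for the example
    span_id=0x2,   # 64-bit span id
    is_remote=False,
    trace_flags=TraceFlags(TraceFlags.SAMPLED),
)

# NonRecordingSpan wraps an existing context so it can serve as a
# parent in trace.set_span_in_context() without recording anything.
parent = NonRecordingSpan(ctx)
assert parent.get_span_context().span_id == 0x2
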
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st


def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)


@pytest.mark.parametrize(
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)


def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()


def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx

data = SpanData(
trace_id=1,
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx


def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)

data = SpanData(
trace_id=1,
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[

client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()


def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")

client._create_and_export_span(
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"


def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()


def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")

client.shutdown()
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None


def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))

@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()


def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)

data = SpanData.model_construct(
trace_id=1,
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock

client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock

client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)

assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"


def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)

assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"


def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)

assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs


def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)

assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs


def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})

with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})

with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})

def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

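The switch from `WeaveConfig()` to `WeaveConfig.model_validate({})` looks cosmetic but matters for static analysis: calling the constructor without required fields is itself a type error, so checkers like basedpyright would flag the test, while `model_validate` accepts arbitrary input and defers the failure to runtime, which is exactly what `pytest.raises` asserts. A minimal sketch (assuming Pydantic v2; `Config` is an invented stand-in):

import pytest
from pydantic import BaseModel, ValidationError


class Config(BaseModel):
    api_key: str
    project: str


def test_missing_required_fields() -> None:
    # Config() with no arguments would be rejected statically;
    # model_validate({}) type-checks and still fails at runtime.
    with pytest.raises(ValidationError):
        Config.model_validate({})
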
@ -6,9 +6,10 @@ requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
"boto3>=1.42.91",
"boto3>=1.42.96",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
"flask-cors>=6.0.2",
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
@ -16,7 +17,7 @@ dependencies = [
"google-api-python-client>=2.194.0",
"gunicorn>=25.3.0",
"psycogreen>=1.0.2",
"psycopg2-binary>=2.9.11",
"psycopg2-binary>=2.9.12",
"python-socketio>=5.13.0",
"redis[hiredis]>=7.4.0",
"sendgrid>=6.12.5",
@ -32,13 +33,13 @@ dependencies = [
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.148.1,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-distro>=0.62b1,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-httpx>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-redis>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-sqlalchemy>=0.62b0,<1.0.0",
"opentelemetry-propagator-b3>=1.41.0,<2.0.0",
"opentelemetry-propagator-b3>=1.41.1,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",

@ -117,7 +118,7 @@ dev = [
"faker>=40.15.0",
"lxml-stubs>=0.5.1",
"basedpyright>=1.39.3",
"ruff>=0.15.11",
"ruff>=0.15.12",
"pytest>=9.0.3",
"pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0",
@ -144,7 +145,7 @@ dev = [
"types-pexpect>=4.9.0",
"types-protobuf>=7.34.1",
"types-psutil>=7.2.2",
"types-psycopg2>=2.9.21",
"types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0",
"types-pymysql>=1.1.0",
"types-python-dateutil>=2.9.0",
@ -157,9 +158,9 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.92",
"boto3-stubs>=1.42.96",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.1",
"hypothesis>=6.152.3",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
@ -169,12 +170,12 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy>=1.20.1",
"mypy>=1.20.2",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.62.0",
"xinference-client>=2.5.0",
"xinference-client>=2.7.0",
]

############################################################
@ -184,12 +185,12 @@ dev = [
storage = [
"azure-storage-blob>=12.28.0",
"bce-python-sdk>=0.9.70",
"cos-python-sdk-v5>=1.9.41",
"cos-python-sdk-v5>=1.9.42",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
"supabase>=2.28.3",
"supabase>=2.29.0",
"tos>=2.9.0",
]

@ -272,7 +273,7 @@ vdb-vastbase = ["dify-vdb-vastbase"]
vdb-vikingdb = ["dify-vdb-vikingdb"]
vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.5.0"]
vdb-xinference = ["xinference-client>=2.7.0"]

trace-all = [
"dify-trace-aliyun",

@ -133,7 +133,14 @@ class AppAnnotationService:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question

annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=None,
message_id=None,
content=answer,
question=question,
account_id=current_user.id,
)
db.session.add(annotation)
db.session.commit()

@ -89,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")

# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)

# commit the read-only session before starting the billing rpc call
session.commit()

# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -302,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]

@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
"""
Get workflow for the app

Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.

Returns:
Workflow instance
@ -318,12 +329,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

@ -3748,6 +3748,7 @@ class SegmentService:
ChildChunk.segment_id == segment.id,
)
)
assert current_user.current_tenant_id
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
@ -3758,7 +3759,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=content,
word_count=len(content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)
db.session.add(child_chunk)
@ -3818,6 +3819,7 @@ class SegmentService:
if new_child_chunks_args:
child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
assert current_user.current_tenant_id
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content)
child_chunk = ChildChunk(
@ -3830,7 +3832,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=args.content,
word_count=len(args.content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)

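Replacing the bare `"customized"` literal with `SegmentType.CUSTOMIZED` routes every chunk type through one definition. Assuming `SegmentType` is a string-valued enum (the diff does not show its declaration), the stored value stays identical while typos become import-time errors. A sketch under that assumption:

from enum import StrEnum


class SegmentType(StrEnum):
    # Assumed members; the real enum lives in models.enums.
    AUTOMATIC = "automatic"
    CUSTOMIZED = "customized"


def make_chunk(chunk_type: SegmentType) -> dict[str, str]:
    return {"type": chunk_type}


chunk = make_chunk(SegmentType.CUSTOMIZED)
# A StrEnum member compares (and serializes) equal to its value,
# so rows written before and after the change look the same.
assert chunk["type"] == "customized"
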
@ -799,50 +799,47 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
workflow_inputs = cls.build_workflow_inputs(webhook_data)

# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)

end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)

try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise

end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)

# reserve quota before triggering workflow execution
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise

# Trigger workflow execution asynchronously
try:
try:
# NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`;
# trigger_workflow_async needs to handle multiple session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
quota_charge.commit()
except Exception:
quota_charge.refund()
raise

except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

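The restructured method above follows a reserve-commit-refund shape: reserve quota before dispatch, commit the charge only after the async trigger succeeds, and refund on any failure so an exception never leaks a reservation. A self-contained sketch of that shape (this quota machinery is a stand-in, not Dify's QuotaService):

class QuotaExceeded(Exception):
    pass


class Charge:
    def __init__(self, account: dict, amount: int) -> None:
        self.account, self.amount = account, amount
        self.settled = False

    def commit(self) -> None:
        self.settled = True  # the reservation becomes a real charge

    def refund(self) -> None:
        if not self.settled:
            self.account["reserved"] -= self.amount


def reserve(account: dict, amount: int = 1) -> Charge:
    if account["reserved"] + amount > account["limit"]:
        raise QuotaExceeded
    account["reserved"] += amount
    return Charge(account, amount)


def dispatch(account: dict, run) -> None:
    charge = reserve(account)  # fail fast when over quota
    try:
        run()
        charge.commit()        # success: keep the charge
    except Exception:
        charge.refund()        # failure: release the reservation
        raise


account = {"reserved": 0, "limit": 1}
dispatch(account, lambda: None)
assert account["reserved"] == 1  # committed charge stays counted
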
@ -16,6 +16,7 @@ from extensions.ext_database import db
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.enums import SegmentType

logger = logging.getLogger(__name__)

@ -178,7 +179,7 @@ class VectorService:
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
@ -222,6 +223,7 @@ class VectorService:
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
assert update_child_chunk.index_node_id
child_document = Document(
page_content=update_child_chunk.content,
metadata={
@ -234,6 +236,7 @@ class VectorService:
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
assert delete_child_chunk.index_node_id
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
@ -246,6 +249,7 @@ class VectorService:
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
assert child_chunk.index_node_id
vector.delete_by_ids([child_chunk.index_node_id])

@classmethod

@ -173,11 +173,18 @@ class WorkflowService:
# return draft workflow
return workflow

def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
) -> Workflow | None:
"""
fetch published workflow by workflow_id

When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
@ -195,16 +202,20 @@ class WorkflowService:
)
return workflow

def get_published_workflow(self, app_model: App) -> Workflow | None:
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
"""
Get published workflow

When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""

if not app_model.workflow_id:
return None

# fetch published workflow by workflow_id
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,

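Both getters above now accept an optional `session`, as the docstrings explain: a caller that already holds a Session passes it in and pays for one connection, while legacy callers fall back to the request-scoped `db.session`. A runnable sketch of the pattern outside Flask (the `Thing` model is illustrative):

from __future__ import annotations

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Thing(Base):
    __tablename__ = "things"
    id: Mapped[str] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)


def get_thing(thing_id: str, session: Session | None = None) -> Thing | None:
    stmt = select(Thing).where(Thing.id == thing_id)
    if session is not None:
        # Reuse the caller's session: no second connection checkout.
        return session.scalar(stmt)
    with Session(engine) as own:  # fallback, analogous to db.session
        return own.scalar(stmt)


with Session(engine) as s:
    s.add(Thing(id="t1"))
    s.commit()
    assert get_thing("t1", session=s) is not None
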
@ -259,59 +259,58 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity

# Ensure expire_on_commit is set to False to keep workflows available
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)

end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)

# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break

if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue

# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
for plugin_trigger in subscribers:
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue

# reserve quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break

node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue

trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)

quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
return dispatched_count

node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None

with session_factory.create_session() as session:
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
@ -403,7 +402,7 @@ def dispatch_triggered_workflow(
plugin_trigger.app_id,
)

return dispatched_count
return dispatched_count


def dispatch_triggered_workflows(

@ -33,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
# Ensure expire_on_commit is set to False to keep schedule/tenant_owner available
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
@ -42,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")

quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return

try:
# Production dispatch: Trigger the workflow normally
try:
with session_factory.create_session() as session:
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
@ -62,10 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e

@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
parent_message_id=None,
)

class MockQuery:
def __init__(self, model):
self.model = model

def where(self, *args, **kwargs):
return self

def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None

def order_by(self, *args, **kwargs):
return self

def limit(self, *_):
return self

def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []

mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
mock_session.scalar.return_value = mock_conversation
mock_session.scalars.return_value.all.return_value = [mock_message]

monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())

class DummyPagination:
def __init__(self, data, limit, has_more):

@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService


def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result

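A plausible reading of `_execute_result`: once the service issues a single `db.session.execute(select(...))` and calls `.all()` on the result, a test only needs to stub the one object that `execute` returns, instead of wiring up a chain of `join`/`where`/`order_by` mocks. A minimal sketch (the `export_rows` function is invented for illustration):

from unittest import mock


def _execute_result(rows):
    result = mock.Mock()
    result.all.return_value = rows
    return result


def export_rows(session) -> list[tuple]:
    stmt = "select ..."  # stands in for a SQLAlchemy select()
    return session.execute(stmt).all()


session = mock.Mock()
session.execute.return_value = _execute_result([("feedback", "message")])
assert export_rows(session) == [("feedback", "message")]
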
class TestFeedbackService:
"""Test FeedbackService methods."""

@ -81,25 +87,17 @@ class TestFeedbackService:

def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""

# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)

# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -120,25 +118,17 @@ class TestFeedbackService:

def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""

# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)

# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -157,25 +147,17 @@ class TestFeedbackService:

def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""

# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)

# Test with filters
result = FeedbackService.export_feedbacks(
@ -193,17 +175,7 @@ class TestFeedbackService:

def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""

# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result([])

result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")

@ -251,24 +223,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)

# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)

# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -309,24 +274,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)

# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None,
)
]
)

# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -339,32 +297,24 @@ class TestFeedbackService:

def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""

# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]

mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
)

# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

@ -10,6 +10,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from enums.quota_type import QuotaType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers:
end_user = SimpleNamespace(id=str(uuid4()))
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}

quota_charge = MagicMock()

with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
return_value=end_user,
),
patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
patch(
"services.trigger.webhook_service.QuotaService.reserve",
return_value=quota_charge,
) as mock_reserve,
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
):
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
mock_reserve.assert_called_once()
reserve_args = mock_reserve.call_args.args
assert reserve_args[0] == QuotaType.TRIGGER
assert reserve_args[1] == webhook_trigger.tenant_id
quota_charge.commit.assert_called_once()
mock_trigger.assert_called_once()
trigger_args = mock_trigger.call_args.args
assert trigger_args[1] is end_user
@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
return_value=SimpleNamespace(id=str(uuid4())),
),
patch(
"services.trigger.webhook_service.QuotaType.TRIGGER.consume",
"services.trigger.webhook_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
),
patch(

@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
configure_session_factory(_unit_test_engine, expire_on_commit=False)


def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
"""
Helper to set up the mock DB execute chain for tenant/account authentication.
Helper to stub the tenant-owner execute result for service API app authentication.

This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
The validate_app_token decorator currently resolves the active tenant owner
via db.session.execute(select(Tenant, Account)...).one_or_none().

Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
mock_owner: Mock owner object to return from the execute result
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)


def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
"""
Helper to set up the mock DB execute chain for dataset tenant authentication.
Helper to stub the tenant-owner execute result for dataset token authentication.

This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
The validate_dataset_token decorator currently resolves the owner mapping via
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
then loads the Account separately via db.session.get(...).

Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
mock_tenant_account_join: Mock tenant-account join object to return
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

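For readers less used to MagicMock attribute chains: assigning to `execute.return_value.one_or_none.return_value` preconfigures every hop of the call path the decorator takes, so the stub works regardless of which select statement is passed in. A tiny sketch of the equivalence:

from unittest.mock import MagicMock

mock_db = MagicMock()
tenant, owner = object(), object()
mock_db.session.execute.return_value.one_or_none.return_value = (tenant, owner)

# Equivalent of what the decorator does internally:
row = mock_db.session.execute("select tenant, owner ...").one_or_none()
assert row == (tenant, owner)
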
@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:

file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")

mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:

file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")

mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")

mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:

file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")

mock_db_session.query.return_value.where.return_value.first.return_value = mock_app

with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")


@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True

# Act
@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists

# Act
with self.app.test_request_context(
@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False

# Act
@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists

# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

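Note: a hedged sketch of the recurring "setup exists" stub used throughout these tests. The console decorators check that a setup record is present; any truthy object returned from the query chain passes the check, so a bare MagicMock() is enough and no database is touched. The view name is illustrative.

from unittest.mock import MagicMock, patch


def sketch_setup_exists_stub():
    with patch("controllers.console.wraps.db") as mock_db:
        # Truthy first() result == "setup is complete" for @setup_required.
        mock_db.session.query.return_value.first.return_value = MagicMock()
        # ... exercise a @setup_required view here ...
        return mock_db
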
@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True

# Act & Assert
@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")

@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None

# Act & Assert
@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}

# Act & Assert
@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}

# Act & Assert
@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None

@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True

# Act & Assert
@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}

@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")

@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())

# Act
@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}


@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagListApi,
TagUpdateDeleteApi,
)
@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
assert status == 204


class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)

payload = {
@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
assert result["result"] == "success"

def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)

with app.test_request_context("/", json={}):
@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
method(api)


class TestTagBindingDeleteApi:
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
method = unwrap(api.post)

payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}

with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)

save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"


class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)

payload = {
"target_id": "target-1",
"type": "knowledge",
}

with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")

delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"

def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)

with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")


class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)

payload = {
@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
assert result["result"] == "success"

def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)

with app.test_request_context("/", json={}):
@ -297,3 +368,35 @@ class TestTagResponseModel:

assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"


class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True

def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"

def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
}
}

assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)

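Note: a brief sketch, not from the diff, of why the metadata assertions above hold. flask-restx keeps per-handler OpenAPI metadata in the handler's __apidoc__ dict, so a route marked deprecated (for example via an api.doc-style decorator) surfaces as __apidoc__["deprecated"] and its operation id as __apidoc__["id"]. The helper name is illustrative.

def is_deprecated(handler) -> bool:
    # Handlers without metadata simply report not-deprecated.
    return getattr(handler, "__apidoc__", {}).get("deprecated") is True
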
@ -24,10 +24,6 @@ def app():
return app


def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()


def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError

_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError

_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError

_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError

_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError

_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True


@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch

import pytest
from flask import Flask, g
@ -16,10 +16,6 @@ def app():
return flask_app


def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()


def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"


@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists

@setup_required
def admin_view():

@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None

@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield


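Note: a small runnable sketch, outside the diff, of the stub change above. The setup check moved from a Query-style chain to session.scalar(select(...)), so the fake session only needs a scalar() callable now; both variants are shown for comparison, with SimpleNamespace keeping the fakes minimal.

from types import SimpleNamespace

# Old shape: the check called session.query(DifySetup).first().
old_session = SimpleNamespace(query=lambda *a, **k: SimpleNamespace(first=lambda: True))
# New shape: the check calls session.scalar(select(DifySetup)).
new_session = SimpleNamespace(scalar=lambda *a, **k: True)

assert old_session.query("DifySetup").first() is True
assert new_session.scalar("select statement") is True
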
@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result


class TestAppParameterApi:
@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -120,7 +120,7 @@ class TestAppParameterApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -161,7 +161,7 @@ class TestAppParameterApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -200,7 +200,7 @@ class TestAppParameterApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -263,7 +263,7 @@ class TestAppMetaApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -331,7 +331,7 @@ class TestAppInfoApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -388,7 +388,7 @@ class TestAppInfoApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -434,7 +434,7 @@ class TestAppInfoApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -486,7 +486,7 @@ class TestAppInfoApi:

mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

@ -15,7 +15,10 @@ from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)


@pytest.fixture
@ -123,9 +126,7 @@ class AuthenticationMocker:
mock_db.session.get.side_effect = [mock_app, mock_tenant]

if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@ -133,8 +134,7 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id

mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)

setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
mock_db.session.get.return_value = mock_account



@ -701,8 +701,8 @@ class TestDocumentApiDelete:
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
internally calls ``validate_and_get_api_token``. To bypass the decorator
we call the original function via ``__wrapped__`` (preserved by
``functools.wraps``). ``delete`` queries the dataset via
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
``functools.wraps``). ``delete`` loads the dataset via
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
controller module.
"""


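Note: an illustrative, runnable sketch (names hypothetical, not Dify's decorator) of the __wrapped__ bypass the docstring describes. functools.wraps stores the undecorated function on the wrapper, so a test can call it directly and skip the billing check.

import functools


def rate_limited(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise RuntimeError("rate limit checked here in production")
    return wrapper


@rate_limited
def delete(dataset_id: str) -> str:
    return f"deleted {dataset_id}"


# The wrapper would normally intercept the call; __wrapped__ skips it.
assert delete.__wrapped__("ds-1") == "deleted ds-1"
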
@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)


@ -141,14 +141,11 @@ class TestValidateAppToken:
mock_account = Mock()
mock_account.id = str(uuid.uuid4())

mock_ta = Mock()
mock_ta.account_id = mock_account.id

# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]

# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

@validate_app_token
def protected_view(app_model):
@ -471,7 +468,7 @@ class TestValidateDatasetToken:
mock_account.current_tenant = mock_tenant

# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)

# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account

@ -22,18 +22,16 @@ class FakeSession:

def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None

def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def get(self, model: type, _ident: object) -> Any:
return self._mapping.get(model.__name__)

def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self

def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def scalar(self, stmt: Any) -> Any:
try:
model = stmt.column_descriptions[0]["entity"]
except (AttributeError, IndexError, KeyError, TypeError):
return None
return self._mapping.get(model.__name__)


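Note: a short sketch, outside the diff, of what the new FakeSession.scalar relies on. A SQLAlchemy 2.x select(Model) statement exposes the selected entity through column_descriptions, so the fake can map a statement back to a model name without executing anything. The Dataset model here is illustrative.

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(primary_key=True)


# The fake reads this same metadata to decide which mapping entry to return.
stmt = select(Dataset)
assert stmt.column_descriptions[0]["entity"] is Dataset
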
class FakeDB:

@ -36,18 +36,6 @@ class _FakeSession:

def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None

def query(self, model):
self._model_name = model.__name__
return self

def where(self, *args, **kwargs):
return self

def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)

def get(self, model, ident):
return self._mapping.get(model.__name__)

@ -34,7 +34,6 @@ def _patch_wraps():
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield



@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()

# Mock GraphRuntimeState to accept the variable pool
@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()

# Mock ConversationVariable.from_variable to return mock objects
@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()

# Mock GraphRuntimeState to accept the variable pool

@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):

workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = workflow
session.get.return_value = workflow
mocker.patch.object(module.db, "session", session)

queue_manager = MagicMock()

@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
app_generate_entity.single_iteration_run = None
app_generate_entity.single_loop_run = None

query = MagicMock()
query.where.return_value.first.return_value = None

session = MagicMock()
session.query.return_value = query
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)

runner = PipelineRunner(
@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
app_generate_entity = _build_app_generate_entity()

pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline

session = MagicMock()
session.query.return_value = query_pipeline
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)

runner = PipelineRunner(

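Note: a runnable sketch, not part of the diff, of the session.get stubbing above. MagicMock consumes a side_effect list one element per call, so [None, pipeline] means the first session.get(...) misses and the second hits; the ordering therefore encodes which lookup fails in each scenario. The model names passed below are placeholders.

from unittest.mock import MagicMock

session = MagicMock()
pipeline = object()
session.get.side_effect = [None, pipeline]

# First lookup misses, second lookup returns the queued object.
assert session.get("SomeModel", "id-1") is None
assert session.get("OtherModel", "id-2") is pipeline
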
@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
"last_edited_time": "2024-11-27T18:00:00.000Z",
}
mock_request.return_value = mock_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query

# Act
extractor_page.update_last_edited_time(mock_document_model)
@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
}

mock_request.side_effect = [last_edited_response, block_response]
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query

# Act
documents = extractor.extract()
@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
}
mock_post.return_value = database_response

mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query

# Act
documents = extractor.extract()


@ -40,11 +40,11 @@ class TestObfuscatedToken:
class TestEncryptToken:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_successful_encryption(self, mock_encrypt, mock_query):
def test_successful_encryption(self, mock_encrypt, mock_get):
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"

result = encrypt_token("tenant-123", "test_token")
@ -53,9 +53,9 @@ class TestEncryptToken:
mock_encrypt.assert_called_with("test_token", "mock_public_key")

@patch("extensions.ext_database.db.session.get")
def test_tenant_not_found(self, mock_query):
def test_tenant_not_found(self, mock_get):
"""Test error when tenant doesn't exist"""
mock_query.return_value = None
mock_get.return_value = None

with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
@patch("libs.rsa.decrypt")
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get):
"""Test that encryption and decryption are consistent"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant

# Setup mock encryption/decryption
original_token = "test_token_123"
@ -148,12 +148,12 @@ class TestSecurity:

@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
def test_cross_tenant_isolation(self, mock_encrypt, mock_get):
"""Ensure tokens encrypted for one tenant cannot be used by another"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"

# Encrypt token for tenant1
@ -183,10 +183,10 @@ class TestSecurity:

@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_encryption_randomness(self, mock_encrypt, mock_query):
def test_encryption_randomness(self, mock_encrypt, mock_get):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant

# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@ -207,11 +207,11 @@ class TestEdgeCases:

@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get):
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"

result = encrypt_token("tenant-123", "")
@ -221,11 +221,11 @@ class TestEdgeCases:

@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get):
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"

# Test various special characters
@ -244,11 +244,11 @@ class TestEdgeCases:

@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get):
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant

# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

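Note: a hedged sketch of the patch-target change running through these tests. encrypt_token now loads the tenant with db.session.get(Tenant, tenant_id) rather than a query(...).first() chain, so the tests patch extensions.ext_database.db.session.get directly. Minimal usage below, mirroring the patch targets the tests use; the return value is whatever encrypt_token builds from the stubbed encrypt result.

from unittest.mock import MagicMock, patch


def sketch_encrypt_token_call():
    with (
        patch("extensions.ext_database.db.session.get") as mock_get,
        patch("libs.rsa.encrypt", return_value=b"cipher"),
    ):
        # A tenant with a public key is all encrypt_token needs from the DB.
        mock_get.return_value = MagicMock(encrypt_public_key="pub")
        from libs.rsa import encrypt_token  # imported under patch, as a sketch
        return encrypt_token("tenant-123", "secret")
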
@ -495,7 +495,7 @@ class TestLLMGenerator:

def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}

@ -521,7 +521,7 @@ class TestLLMGenerator:

def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
# Cause exception in node_type logic
workflow.graph_dict = {"graph": {"nodes": []}}
@ -548,7 +548,7 @@ class TestLLMGenerator:

def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}

@ -636,7 +636,7 @@ class TestLLMGenerator:
instance.invoke_llm.return_value = mock_response

with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}


@ -29,15 +29,6 @@ class _Field:
return ("in", self._name, tuple(values))


class _FakeQuery:
def __init__(self):
self.where_calls: list[tuple] = []

def where(self, *conditions):
self.where_calls.append(conditions)
return self


class _FakeExecuteResult:
def __init__(self, segments: list[SimpleNamespace]):
self._segments = segments

@ -109,17 +109,6 @@ class _FakeExecuteResult:
return _FakeExecuteScalarResult(self._data)


class _FakeSummaryQuery:
def __init__(self, summaries: list) -> None:
self._summaries = summaries

def filter(self, *args, **kwargs):
return self

def all(self) -> list:
return self._summaries


class _FakeScalarsResult:
def __init__(self, data: list) -> None:
self._data = data

@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module):


def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
class _Field:
def __eq__(self, value):
return value

upload_query = MagicMock()
upload_query.where.return_value = upload_query

vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
vector._embeddings = MagicMock()
vector._vector_processor = MagicMock()

mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))

assert vector.search_by_file("file-1") == []

@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk:

mock_dependencies["redis"].get.return_value = None

# Mock database query for segment updates
mock_query = MagicMock()
mock_dependencies["db"].session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.update.return_value = None
# Mock database update for segment status
mock_dependencies["db"].session.execute.return_value = None

# Create a proper context manager mock
mock_context = MagicMock()

@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}

mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]

# Act
result = dataset_retrieval.knowledge_retrieval(request)

@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}

mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]

# Act
result = dataset_retrieval.knowledge_retrieval(request)

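Note: a runnable sketch, outside the diff, of the mocking pattern above. knowledge_retrieval now issues two session.scalars(select(...)) calls (datasets first, then documents), so the test queues one result object per call via side_effect, each answering .all(). The statement strings below stand in for real select() objects.

from unittest.mock import MagicMock

mock_session = MagicMock()
mock_datasets, mock_documents = MagicMock(), MagicMock()
mock_datasets.all.return_value = ["dataset"]
mock_documents.all.return_value = ["document"]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]

# Each scalars() call pops the next queued result, in order.
assert mock_session.scalars("first select").all() == ["dataset"]
assert mock_session.scalars("second select").all() == ["document"]
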
@ -711,6 +711,8 @@ class TestMessageAnnotation:
annotation = MessageAnnotation(
app_id=app_id,
question="What is AI?",
conversation_id=None,
message_id=None,
content="AI stands for Artificial Intelligence.",
account_id=account_id,
)
@ -728,6 +730,8 @@ class TestMessageAnnotation:
annotation = MessageAnnotation(
app_id=str(uuid4()),
question="Test question",
conversation_id=None,
message_id=None,
content="Test content",
account_id=str(uuid4()),
)
@ -1068,6 +1072,8 @@ class TestModelIntegration:
app_id=app_id,
question="What is AI?",
content="AI stands for Artificial Intelligence.",
conversation_id=None,
message_id=message_id,
account_id=account_id,
)
annotation.id = annotation_id

@ -365,7 +365,6 @@ def _make_segment(

def _make_child_chunk() -> ChildChunk:
return ChildChunk(
id="child-a",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",

File diff suppressed because it is too large
@ -1,925 +0,0 @@
"""
Extensive unit tests for ``ExternalDatasetService``.

This module focuses on the *external dataset service* surface area, which is responsible
for integrating with **external knowledge APIs** and wiring them into Dify datasets.

The goal of this test suite is twofold:

- Provide **high‑confidence regression coverage** for all public helpers on
``ExternalDatasetService``.
- Serve as **executable documentation** for how external API integration is expected
to behave in different scenarios (happy paths, validation failures, and error codes).

The file intentionally contains **rich comments and generous spacing** in order to make
each scenario easy to scan during reviews.
"""

from __future__ import annotations

from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch

import httpx
import pytest

from constants import HIDDEN_VALUE
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
AuthorizationConfig,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
from services.external_knowledge_service import ExternalDatasetService


class ExternalDatasetTestDataFactory:
"""
Factory helpers for building *lightweight* mocks for external knowledge tests.

These helpers are intentionally small and explicit:

- They avoid pulling in unnecessary fixtures.
- They reflect the minimal contract that the service under test cares about.
"""

@staticmethod
def create_external_api(
api_id: str = "api-123",
tenant_id: str = "tenant-1",
name: str = "Test API",
description: str = "Description",
settings: dict[str, Any] | None = None,
) -> ExternalKnowledgeApis:
"""
Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.

Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
exercise ``settings_dict`` and other convenience properties if needed.
"""

instance = ExternalKnowledgeApis(
tenant_id=tenant_id,
name=name,
description=description,
settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment]
)

# Overwrite generated id for determinism in assertions.
instance.id = api_id
return instance

@staticmethod
def create_dataset(
dataset_id: str = "ds-1",
tenant_id: str = "tenant-1",
name: str = "External Dataset",
provider: str = "external",
) -> Dataset:
"""
Build a small ``Dataset`` instance representing an external dataset.
"""

dataset = Dataset(
tenant_id=tenant_id,
name=name,
description="",
provider=provider,
created_by="user-1",
)
dataset.id = dataset_id
return dataset

@staticmethod
def create_external_binding(
tenant_id: str = "tenant-1",
dataset_id: str = "ds-1",
api_id: str = "api-1",
external_knowledge_id: str = "knowledge-1",
) -> ExternalKnowledgeBindings:
"""
Small helper for a binding between dataset and external knowledge API.
"""

binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset_id,
external_knowledge_api_id=api_id,
external_knowledge_id=external_knowledge_id,
created_by="user-1",
)
return binding


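Note: a short usage sketch for the factory above (which belongs to the file this diff removes; shown only for orientation, values mirror the factory defaults):

api = ExternalDatasetTestDataFactory.create_external_api()
dataset = ExternalDatasetTestDataFactory.create_dataset()
binding = ExternalDatasetTestDataFactory.create_external_binding()
assert (api.id, dataset.id, binding.dataset_id) == ("api-123", "ds-1", "ds-1")
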
# ---------------------------------------------------------------------------
|
||||
# get_external_knowledge_apis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceGetExternalKnowledgeApis:
|
||||
"""
|
||||
Tests for ``ExternalDatasetService.get_external_knowledge_apis``.
|
||||
|
||||
These tests focus on:
|
||||
|
||||
- Basic pagination wiring via ``db.paginate``.
|
||||
- Optional search keyword behaviour.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_paginate(self):
|
||||
"""
|
||||
Patch ``db.paginate`` so we do not touch the real database layer.
|
||||
"""
|
||||
|
||||
with (
|
||||
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
|
||||
patch("services.external_knowledge_service.select", autospec=True),
|
||||
):
|
||||
yield mock_paginate
|
||||
|
||||
def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
|
||||
"""
|
||||
It should return ``items`` and ``total`` coming from the paginate object.
|
||||
"""
|
||||
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
|
||||
mock_pagination = SimpleNamespace(items=mock_items, total=42)
|
||||
mock_db_paginate.return_value = mock_pagination
|
||||
|
||||
# Act
|
||||
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert items is mock_items
|
||||
assert total == 42
|
||||
|
||||
mock_db_paginate.assert_called_once()
|
||||
call_kwargs = mock_db_paginate.call_args.kwargs
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == per_page
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
assert call_kwargs["error_out"] is False
|
||||
|
||||
def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
|
||||
"""
|
||||
When a search keyword is provided, the query should be adjusted
|
||||
(we simply assert that paginate is still called and does not explode).
|
||||
"""
|
||||
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
page = 2
|
||||
per_page = 10
|
||||
search = "foo"
|
||||
|
||||
mock_pagination = SimpleNamespace(items=[], total=0)
|
||||
mock_db_paginate.return_value = mock_pagination
|
||||
|
||||
# Act
|
||||
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)
|
||||
|
||||
# Assert
|
||||
assert items == []
|
||||
assert total == 0
|
||||
mock_db_paginate.assert_called_once()
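
# A minimal sketch (not the production code) of the pagination wiring asserted
# above: a tenant-scoped select handed to Flask-SQLAlchemy's ``db.paginate``
# with ``max_per_page=100`` and ``error_out=False``. The function name and the
# ``select_stmt`` parameter are illustrative assumptions.
def _sketch_paginate_external_apis(db, select_stmt, page: int, per_page: int):
    # db.paginate returns an object exposing .items and .total, which is then
    # unpacked into the (items, total) tuple the tests above assert on.
    pagination = db.paginate(select_stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
    return pagination.items, pagination.total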


# ---------------------------------------------------------------------------
# validate_api_list
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceValidateApiList:
    """
    Lightweight validation tests for ``validate_api_list``.
    """

    def test_validate_api_list_success(self):
        """
        A minimal valid configuration (endpoint + api_key) should pass.
        """

        config = {"endpoint": "https://example.com", "api_key": "secret"}

        # Act & Assert – no exception expected
        ExternalDatasetService.validate_api_list(config)

    @pytest.mark.parametrize(
        ("config", "expected_message"),
        [
            ({}, "api list is empty"),
            ({"api_key": "k"}, "endpoint is required"),
            ({"endpoint": "https://example.com"}, "api_key is required"),
        ],
    )
    def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str):
        """
        Invalid configs should raise ``ValueError`` with a clear message.
        """

        with pytest.raises(ValueError, match=expected_message):
            ExternalDatasetService.validate_api_list(config)
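
# Sketch of the validation rules exercised above, inferred from the asserted
# error messages; the real implementation may differ in detail.
def _sketch_validate_api_list(api_settings: dict) -> None:
    # An empty mapping fails first, then each required key is checked in turn,
    # mirroring the parametrized cases above.
    if not api_settings:
        raise ValueError("api list is empty")
    if "endpoint" not in api_settings:
        raise ValueError("endpoint is required")
    if "api_key" not in api_settings:
        raise ValueError("api_key is required")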


# ---------------------------------------------------------------------------
# create_external_knowledge_api & get/update/delete
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceCrudExternalKnowledgeApi:
    """
    CRUD tests for external knowledge API templates.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Patch ``db.session`` for all CRUD tests in this class.
        """

        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``create_external_knowledge_api`` should persist a new record
        when settings are present and valid.
        """

        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {
            "name": "API",
            "description": "desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
        }

        # We do not want to actually call the remote endpoint here, so we patch the validator.
        with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
            result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)

        assert isinstance(result, ExternalKnowledgeApis)
        mock_check.assert_called_once_with(args["settings"])
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
        """
        Missing ``settings`` should result in a ``ValueError``.
        """

        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {"name": "API", "description": "desc"}

        with pytest.raises(ValueError, match="settings is required"):
            ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)

        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
        """
        ``get_external_knowledge_api`` should return the first matching record.
        """

        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = api

        result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
        assert result is api

    def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        When the record is absent, a ``ValueError`` is raised.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")

    def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
        """
        Updating an API should keep the existing API key when the special hidden
        value placeholder is sent from the UI.
        """

        tenant_id = "tenant-1"
        user_id = "user-1"
        api_id = "api-1"

        existing_api = Mock(spec=ExternalKnowledgeApis)
        existing_api.settings_dict = {"api_key": "stored-key"}
        existing_api.settings = '{"api_key":"stored-key"}'
        mock_db_session.scalar.return_value = existing_api

        args = {
            "name": "New Name",
            "description": "New Desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
        }

        result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)

        assert result is existing_api
        # The placeholder should be replaced with the stored key.
        assert args["settings"]["api_key"] == "stored-key"
        mock_db_session.commit.assert_called_once()

    def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Updating a non-existent API template should raise ``ValueError``.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.update_external_knowledge_api(
                tenant_id="tenant-1",
                user_id="user-1",
                external_knowledge_api_id="missing-id",
                args={"name": "n", "description": "d", "settings": {}},
            )

    def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``delete_external_knowledge_api`` should delete and commit when the record is found.
        """

        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = api

        ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")

        mock_db_session.delete.assert_called_once_with(api)
        mock_db_session.commit.assert_called_once()

    def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Deletion of a missing template should raise ``ValueError``.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
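
# Sketch of the hidden-key handling verified above: when the UI sends the
# HIDDEN_VALUE placeholder instead of a real key, the previously stored key is
# kept. The function name and the in-place mutation of ``new_settings`` mirror
# what the test observes on ``args["settings"]``; the rest is assumed.
def _sketch_restore_hidden_api_key(new_settings: dict, stored_settings: dict) -> dict:
    if new_settings.get("api_key") == HIDDEN_VALUE:
        # Swap the placeholder for the key already persisted on the record.
        new_settings["api_key"] = stored_settings.get("api_key")
    return new_settings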


# ---------------------------------------------------------------------------
# external_knowledge_api_use_check & binding lookups
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceUsageAndBindings:
    """
    Tests for usage checks and dataset binding retrieval.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
        """
        When there are bindings, ``external_knowledge_api_use_check`` returns True and the count.
        """

        mock_db_session.scalar.return_value = 3

        in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")

        assert in_use is True
        assert count == 3
        assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])

    def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
        """
        Zero bindings should return ``(False, 0)``.
        """

        mock_db_session.scalar.return_value = 0

        in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")

        assert in_use is False
        assert count == 0

    def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
        """
        The binding lookup should return the first record when present.
        """

        binding = Mock(spec=ExternalKnowledgeBindings)
        mock_db_session.scalar.return_value = binding

        result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
        assert result is binding

    def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
        """
        A missing binding should result in a ``ValueError``.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
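
# Sketch of the usage check pinned down above: a tenant-scoped COUNT over
# ExternalKnowledgeBindings folded into an (in_use, count) tuple. The helper
# name is illustrative; only the count-to-tuple mapping is asserted.
def _sketch_use_check(binding_count: int) -> tuple[bool, int]:
    # 3 bindings -> (True, 3); 0 bindings -> (False, 0), exactly as asserted above.
    return binding_count > 0, binding_count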


# ---------------------------------------------------------------------------
# document_create_args_validate
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceDocumentCreateArgsValidate:
    """
    Tests for ``document_create_args_validate``.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
        """
        All required custom parameters are present – validation should pass.
        """

        external_api = Mock(spec=ExternalKnowledgeApis)
        # Raw string; the service itself calls json.loads on it.
        external_api.settings = json_settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.scalar.return_value = external_api

        process_parameter = {"foo": "value", "bar": "optional"}

        # Act & Assert – no exception
        ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)

        assert json_settings in external_api.settings  # simple sanity check on our test data

    def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
        """
        When the referenced API template is missing, a ``ValueError`` is raised.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})

    def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
        """
        Required document process parameters must be supplied.
        """

        external_api = Mock(spec=ExternalKnowledgeApis)
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.scalar.return_value = external_api

        process_parameter = {"bar": "present"}  # missing "foo"

        with pytest.raises(ValueError, match="foo is required"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
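
# Sketch of the required-parameter loop these tests assume: the template's
# ``settings`` JSON carries a ``document_process_setting`` list, and every
# entry marked required must appear in the caller's process parameters.
# Illustrative only; the production code may structure this differently.
def _sketch_validate_process_parameters(settings_json: str, process_parameter: dict) -> None:
    import json  # local import keeps this sketch self-contained

    for setting in json.loads(settings_json):
        for item in setting.get("document_process_setting", []):
            if item.get("required") and item["name"] not in process_parameter:
                raise ValueError(f"{item['name']} is required")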


# ---------------------------------------------------------------------------
# process_external_api
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceProcessExternalApi:
    """
    Tests focused on the HTTP request assembly and method mapping behaviour.
    """

    def test_process_external_api_valid_method_post(self):
        """
        For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
        """

        settings = ExternalKnowledgeApiSetting(
            url="https://example.com/path",
            request_method="POST",
            headers={"X-Test": "1"},
            params={"foo": "bar"},
        )

        fake_response = httpx.Response(200)

        with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
            mock_post.return_value = fake_response

            result = ExternalDatasetService.process_external_api(settings, files=None)

        assert result is fake_response
        mock_post.assert_called_once()
        kwargs = mock_post.call_args.kwargs
        assert kwargs["url"] == settings.url
        assert kwargs["headers"] == settings.headers
        assert kwargs["follow_redirects"] is True
        assert "data" in kwargs

    def test_process_external_api_invalid_method_raises(self):
        """
        An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
        """

        settings = ExternalKnowledgeApiSetting(
            url="https://example.com",
            request_method="INVALID",
            headers=None,
            params={},
        )

        from core.workflow.nodes.http_request.exc import InvalidHttpMethodError

        with pytest.raises(InvalidHttpMethodError):
            ExternalDatasetService.process_external_api(settings, files=None)
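
# Sketch of the verb dispatch exercised above: the request method is mapped to
# the matching ``ssrf_proxy`` helper and anything unknown raises. Only the POST
# branch is verified by the tests; the full mapping below is an assumption, and
# the real code raises InvalidHttpMethodError rather than ValueError.
def _sketch_dispatch_http_method(ssrf_proxy_module, method: str):
    handlers = {
        "GET": ssrf_proxy_module.get,
        "POST": ssrf_proxy_module.post,
        "PUT": ssrf_proxy_module.put,
        "DELETE": ssrf_proxy_module.delete,
    }
    try:
        return handlers[method.upper()]
    except KeyError:
        raise ValueError(f"unsupported HTTP method: {method}") from None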


# ---------------------------------------------------------------------------
# assembling_headers
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceAssemblingHeaders:
    """
    Tests for header assembly based on different authentication flavours.
    """

    def test_assembling_headers_bearer_token(self):
        """
        For bearer auth we expect ``Authorization: Bearer <key>`` by default.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
        )

        headers = ExternalDatasetService.assembling_headers(auth)

        assert headers["Authorization"] == "Bearer secret"

    def test_assembling_headers_basic_token_with_custom_header(self):
        """
        For basic auth we honour the configured header name.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
        )

        headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})

        assert headers["Existing"] == "1"
        assert headers["X-Auth"] == "Basic abc123"

    def test_assembling_headers_custom_type(self):
        """
        The custom auth type should inject the raw API key.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
        )

        headers = ExternalDatasetService.assembling_headers(auth, headers=None)

        assert headers["X-API-KEY"] == "raw-key"

    def test_assembling_headers_missing_config_raises(self):
        """
        A missing config object should be rejected.
        """

        auth = Authorization(type="api-key", config=None)

        with pytest.raises(ValueError, match="authorization config is required"):
            ExternalDatasetService.assembling_headers(auth)

    def test_assembling_headers_missing_api_key_raises(self):
        """
        ``api_key`` is required when the type is ``api-key``.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
        )

        with pytest.raises(ValueError, match="api_key is required"):
            ExternalDatasetService.assembling_headers(auth)

    def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
        """
        For ``no-auth`` we should not modify the headers mapping.
        """

        auth = Authorization(type="no-auth", config=None)

        base_headers = {"X": "1"}
        result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)

        # A copy is returned; the original is not mutated.
        assert result == base_headers
        assert result is not base_headers
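
# Sketch of the header assembly rules covered above, derived purely from the
# assertions (the production code may differ). ``auth_type``/``header``/
# ``api_key`` mirror the AuthorizationConfig fields used in the tests.
def _sketch_assemble_headers(headers: dict | None, auth_type: str, header: str | None, api_key: str) -> dict:
    result = dict(headers or {})  # copy, so the caller's mapping is never mutated
    header = header or "Authorization"  # default header name when none is configured
    if auth_type == "bearer":
        result[header] = f"Bearer {api_key}"
    elif auth_type == "basic":
        result[header] = f"Basic {api_key}"
    elif auth_type == "custom":
        result[header] = api_key
    return result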


# ---------------------------------------------------------------------------
# get_external_knowledge_api_settings
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
    """
    Simple shape test for ``get_external_knowledge_api_settings``.
    """

    def test_get_external_knowledge_api_settings(self):
        settings_dict: dict[str, Any] = {
            "url": "https://example.com/retrieval",
            "request_method": "post",
            "headers": {"Content-Type": "application/json"},
            "params": {"foo": "bar"},
        }

        result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)

        assert isinstance(result, ExternalKnowledgeApiSetting)
        assert result.url == settings_dict["url"]
        assert result.request_method == settings_dict["request_method"]
        assert result.headers == settings_dict["headers"]
        assert result.params == settings_dict["params"]


# ---------------------------------------------------------------------------
# create_external_dataset
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceCreateExternalDataset:
    """
    Tests around creating the external dataset and its binding row.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_create_external_dataset_success(self, mock_db_session: MagicMock):
        """
        A brand new dataset name with valid external knowledge references
        should create both the dataset and its binding.
        """

        tenant_id = "tenant-1"
        user_id = "user-1"

        args = {
            "name": "My Dataset",
            "description": "desc",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 3},
        }

        # No existing dataset with the same name.
        mock_db_session.scalar.side_effect = [
            None,  # duplicate-name check
            Mock(spec=ExternalKnowledgeApis),  # external knowledge api
        ]

        dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)

        assert isinstance(dataset, Dataset)
        assert dataset.provider == "external"
        assert dataset.retrieval_model == args["external_retrieval_model"]

        assert mock_db_session.add.call_count >= 2  # dataset + binding
        mock_db_session.flush.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
        """
        When a dataset with the same name already exists,
        ``DatasetNameDuplicateError`` is raised.
        """

        existing_dataset = Mock(spec=Dataset)
        mock_db_session.scalar.return_value = existing_dataset

        args = {
            "name": "Existing",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
        }

        with pytest.raises(DatasetNameDuplicateError):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)

        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
        """

        # First call: duplicate-name check – not found.
        mock_db_session.scalar.side_effect = [
            None,
            None,  # external knowledge api lookup
        ]

        args = {
            "name": "Dataset",
            "external_knowledge_api_id": "missing",
            "external_knowledge_id": "knowledge-1",
        }

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)

    def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
        """
        ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
        """

        # Duplicate-name checks: the two create_external_dataset calls below each make 2 scalar calls.
        mock_db_session.scalar.side_effect = [
            None,
            Mock(spec=ExternalKnowledgeApis),
            None,
            Mock(spec=ExternalKnowledgeApis),
        ]

        args_missing_knowledge_id = {
            "name": "Dataset",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": None,
        }

        with pytest.raises(ValueError, match="external_knowledge_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)

        args_missing_api_id = {
            "name": "Dataset",
            "external_knowledge_api_id": None,
            "external_knowledge_id": "k-1",
        }

        with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
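
# Sketch of the creation order these tests pin down: name-duplication and
# template checks first, then dataset add + flush (so dataset.id exists for the
# binding row), then the binding, then a single commit. ``make_binding`` is an
# illustrative stand-in for building the ExternalKnowledgeBindings row.
def _sketch_create_external_dataset_flow(session, dataset, make_binding):
    session.add(dataset)
    session.flush()  # assigns dataset.id before the binding references it
    session.add(make_binding(dataset.id))
    session.commit()  # one commit covering both rows, as asserted above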


# ---------------------------------------------------------------------------
# fetch_external_knowledge_retrieval
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
    """
    Tests for ``fetch_external_knowledge_retrieval``, which orchestrates
    external retrieval requests and normalises the response payload.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
        """
        With a valid binding and API template, records from the external
        service should be returned when the HTTP response is 200.
        """

        tenant_id = "tenant-1"
        dataset_id = "ds-1"
        query = "test query"
        external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}

        binding = ExternalDatasetTestDataFactory.create_external_binding(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            api_id="api-1",
            external_knowledge_id="knowledge-1",
        )

        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'

        # First query: binding; second query: api.
        mock_db_session.scalar.side_effect = [
            binding,
            api,
        ]

        fake_records = [{"content": "doc", "score": 0.9}]
        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 200
        fake_response.json.return_value = {"records": fake_records}

        metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})

        with patch.object(
            ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
        ) as mock_process:
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                query=query,
                external_retrieval_parameters=external_retrieval_parameters,
                metadata_condition=metadata_condition,
            )

        assert result == fake_records

        mock_process.assert_called_once()
        setting_arg = mock_process.call_args.args[0]
        assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
        assert setting_arg.url.endswith("/retrieval")

    def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
        """
        A missing binding should raise ``ValueError``.
        """

        mock_db_session.scalar.return_value = None

        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="missing",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        When the API template is missing or has no settings, a ``ValueError`` is raised.
        """

        binding = ExternalDatasetTestDataFactory.create_external_binding()
        mock_db_session.scalar.side_effect = [
            binding,
            None,
        ]

        with pytest.raises(ValueError, match="external api template not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
        """
        Non-200 responses should be treated as an empty result set.
        """

        binding = ExternalDatasetTestDataFactory.create_external_binding()
        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'

        mock_db_session.scalar.side_effect = [
            binding,
            api,
        ]

        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 500
        fake_response.json.return_value = {}

        with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

        assert result == []
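
# Sketch of the retrieval behaviour asserted above: the template's endpoint is
# suffixed with "/retrieval" before the request goes out, and only an HTTP 200
# yields records. The "records" key comes from the test payload; other field
# handling is assumed.
def _sketch_parse_retrieval_response(response) -> list:
    # Any non-200 status collapses to an empty result set.
    if response.status_code == 200:
        return response.json().get("records", [])
    return []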

@@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None:
    mock_db = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db)
    mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService")
    mock_dataset_service = mock_dataset_service_class.return_value

    # 6. Mock session and its scalar/query methods
    # 6. Mock session and dataset lookup
    mock_session = mocker.Mock()
    mock_session.scalar.return_value = draft_wf

    # Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first)
    dataset = mocker.Mock()
    dataset.retrieval_model_dict = {}
    dataset_query = mocker.Mock()
    dataset_query.where.return_value.first.return_value = dataset

    # Mock node execution copy
    node_exec_query = mocker.Mock()
    node_exec_query.where.return_value.all.return_value = []

    # Mocked session query side effects
    mock_session.query.side_effect = [node_exec_query, dataset_query]
    pipeline.retrieve_dataset.return_value = dataset

    # 7. Run test
    result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account)
@@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
    )

    document = SimpleNamespace(indexing_status="waiting", error=None)
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
    add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
    commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc


def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)

    with pytest.raises(ValueError, match="Dataset not found"):
@@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)

def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
    dataset = SimpleNamespace(pipeline_id="p1")
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])

    with pytest.raises(ValueError, match="Pipeline not found"):
@@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:

def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
    template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
    commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_

def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
    pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])

    with pytest.raises(ValueError, match="Workflow not found"):
@@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin

def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
    exec_log = SimpleNamespace(pipeline_id="p1")
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)

@@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_
def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
    exec_log = SimpleNamespace(pipeline_id="p1")
    pipeline = SimpleNamespace(id="p1")
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
    mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
@@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
    workflow = SimpleNamespace(
        graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
    )
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
    mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)

@@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
        graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
        rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
    )
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
    mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
    mocker.patch(
@@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
            {"variable": "v3", "belong_to_node_id": "shared"},
        ],
    )
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
    mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
    mocker.patch(
@@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
    dataset = SimpleNamespace(pipeline_id="p1")
    pipeline = SimpleNamespace(id="p1")
    query = mocker.Mock()
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])

    result = rag_pipeline_service.get_pipeline("t1", "d1")
@@ -1,59 +0,0 @@
from unittest.mock import MagicMock


class ServiceDbTestHelper:
    """
    Helper class for service database query tests.
    """

    @staticmethod
    def setup_db_query_filter_by_mock(mock_db, query_results):
        """
        Smart database query mock that responds based on model type and query parameters.

        Args:
            mock_db: Mock database session
            query_results: Dict mapping (model_name, filter_key, filter_value) to return value
                Example: {('Account', 'email', 'test@example.com'): mock_account}
        """

        def query_side_effect(model):
            mock_query = MagicMock()

            def filter_by_side_effect(**kwargs):
                mock_filter_result = MagicMock()

                def first_side_effect():
                    # Find matching result based on model and filter parameters
                    for (model_name, filter_key, filter_value), result in query_results.items():
                        if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
                            return result
                    return None

                mock_filter_result.first.side_effect = first_side_effect

                # Handle order_by calls for complex queries
                def order_by_side_effect(*args, **kwargs):
                    mock_order_result = MagicMock()

                    def order_first_side_effect():
                        # Look for order_by results in the same query_results dict
                        for (model_name, filter_key, filter_value), result in query_results.items():
                            if (
                                model.__name__ == model_name
                                and filter_key == "order_by"
                                and filter_value == "first_available"
                            ):
                                return result
                        return None

                    mock_order_result.first.side_effect = order_first_side_effect
                    return mock_order_result

                mock_filter_result.order_by.side_effect = order_by_side_effect
                return mock_filter_result

            mock_query.filter_by.side_effect = filter_by_side_effect
            return mock_query

        mock_db.session.query.side_effect = query_side_effect
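
# For reference (this commit removes the helper above), a minimal usage sketch:
# keys are (model name, filter_by kwarg, expected value) and the mapped mock is
# what query(Model).filter_by(...).first() then returns. Names are illustrative.
def _sketch_helper_usage(mock_db, mock_account, Account):
    query_results = {("Account", "email", "test@example.com"): mock_account}
    ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db, query_results)
    found = mock_db.session.query(Account).filter_by(email="test@example.com").first()
    assert found is mock_account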

@@ -14,7 +14,6 @@ from services.errors.account import (
    AccountRegisterError,
    CurrentPasswordIncorrectError,
)
from tests.unit_tests.services.services_test_help import ServiceDbTestHelper


class TestAccountAssociatedDataFactory:
@@ -149,7 +148,6 @@ class TestAccountService:
        # Setup basic session methods
        mock_session.add = MagicMock()
        mock_session.commit = MagicMock()
        mock_session.query = MagicMock()

        yield mock_db

@@ -1572,15 +1570,9 @@ class TestRegisterService:
            account_id="existing-user-456", email="existing@example.com", status="active"
        )

        # Mock database queries
        query_results = {
            (
                "TenantAccountJoin",
                "tenant_id",
                "tenant-456",
            ): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
        }
        ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
        mock_db_dependencies[
            "db"
        ].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock()

        # Mock TenantService methods
        with (

@@ -238,6 +238,8 @@ class TestAppAnnotationServiceUpInsert:
        assert result == annotation_instance
        mock_cls.assert_called_once_with(
            app_id=app.id,
            conversation_id=None,
            message_id=None,
            content="hello",
            question="q1",
            account_id=current_user.id,

@@ -163,7 +163,7 @@ class TestAsyncWorkflowService:

        mocks["quota_service"].reserve.assert_called_once()
        quota_charge_mock.commit.assert_called_once()
        assert session.commit.call_count == 2
        assert session.commit.call_count == 3

        created_log = mocks["repo"].create.call_args[0][0]
        assert created_log.status == WorkflowTriggerStatus.QUEUED
@@ -266,7 +266,7 @@ class TestAsyncWorkflowService:
            trigger_data=trigger_data,
        )

        assert session.commit.call_count == 2
        assert session.commit.call_count == 3
        updated_log = mocks["repo"].update.call_args[0][0]
        assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
        assert "Quota limit reached" in updated_log.error
@@ -469,7 +469,7 @@ class TestAsyncWorkflowServiceGetWorkflow:

        # Assert
        assert result == workflow
        workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
        workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None)
        workflow_service.get_published_workflow.assert_not_called()

    def test_should_raise_when_specific_workflow_id_not_found(self):
@@ -497,7 +497,7 @@ class TestAsyncWorkflowServiceGetWorkflow:

        # Assert
        assert result == workflow
        workflow_service.get_published_workflow.assert_called_once_with(app_model)
        workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
        workflow_service.get_published_workflow_by_id.assert_not_called()

    def test_should_raise_when_default_published_workflow_not_found(self):

@@ -89,7 +89,6 @@ class TestSegmentServiceChildChunks:
        document = _make_document()
        segment = _make_segment()
        existing_a = ChildChunk(
            id="child-a",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
@@ -100,7 +99,6 @@ class TestSegmentServiceChildChunks:
            created_by="user-1",
        )
        existing_b = ChildChunk(
            id="child-b",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
@@ -110,7 +108,8 @@ class TestSegmentServiceChildChunks:
            word_count=9,
            created_by="user-1",
        )

        existing_a.id = "child-a"
        existing_b.id = "child-b"
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.uuid.uuid4", return_value="node-new"),
@@ -714,7 +713,6 @@ class TestSegmentServiceMutations:
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
        ):
            segments_query = MagicMock()
            # execute().all() for segments_info (multi-column)
            execute_result = MagicMock()
            execute_result.all.return_value = [

@@ -36,9 +36,7 @@ class TestDatasourceProviderService:
    @pytest.fixture
    def mock_db_session(self):
        """
        Robust, chainable query mock.
        q returns itself for .filter_by(), .order_by(), .where() so any
        SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
        Mock session with scalar/scalars defaults for current SQLAlchemy access paths.
        """
        with (
            patch("services.datasource_provider_service.Session") as mock_cls,
@@ -46,20 +44,6 @@ class TestDatasourceProviderService:
        ):
            sess = MagicMock(spec=Session)

            q = MagicMock()
            sess.query.return_value = q

            # Self-returning chain — any method called on q returns q
            q.filter_by.return_value = q
            q.order_by.return_value = q
            q.where.return_value = q

            # Default terminal values (tests override per-case)
            q.first.return_value = None
            q.all.return_value = []
            q.count.return_value = 0
            q.delete.return_value = 1

            # Default values for select()-style calls (tests override per-case)
            sess.scalar.return_value = None
            sess.scalars.return_value.all.return_value = []

@@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService


class _FakeQuery:
    def __init__(self, result: Any) -> None:
        self._result = result

    def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
        return self

    def first(self) -> Any:
        return self._result


@pytest.fixture
def flask_app() -> Flask:
    return Flask(__name__)