Mirror of https://github.com/langgenius/dify.git, synced 2026-05-13 00:33:37 +08:00

Merge branch 'main' into tp

This commit is contained in: commit bdecea34a3
.github/workflows/api-tests.yml (vendored, 6 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   api-unit:
     name: API Unit Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       COVERAGE_FILE: coverage-unit
     defaults:
@@ -62,7 +62,7 @@ jobs:

   api-integration:
     name: API Integration Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       COVERAGE_FILE: coverage-integration
       STORAGE_TYPE: opendal
@@ -137,7 +137,7 @@ jobs:

   api-coverage:
     name: API Coverage
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     needs:
       - api-unit
       - api-integration
.github/workflows/autofix.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ permissions:
 jobs:
   autofix:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
.github/workflows/build-push.yml (vendored, 46 changes)

@@ -26,6 +26,9 @@ jobs:
   build:
     runs-on: ${{ matrix.runs_on }}
     if: github.repository == 'langgenius/dify'
+    permissions:
+      contents: read
+      id-token: write
     strategy:
       matrix:
         include:
@@ -35,28 +38,28 @@ jobs:
            build_context: "{{defaultContext}}:api"
            file: "Dockerfile"
            platform: linux/amd64
-           runs_on: ubuntu-latest
+           runs_on: depot-ubuntu-24.04-4
          - service_name: "build-api-arm64"
            image_name_env: "DIFY_API_IMAGE_NAME"
            artifact_context: "api"
            build_context: "{{defaultContext}}:api"
            file: "Dockerfile"
            platform: linux/arm64
-           runs_on: ubuntu-24.04-arm
+           runs_on: depot-ubuntu-24.04-4
          - service_name: "build-web-amd64"
            image_name_env: "DIFY_WEB_IMAGE_NAME"
            artifact_context: "web"
            build_context: "{{defaultContext}}"
            file: "web/Dockerfile"
            platform: linux/amd64
-           runs_on: ubuntu-latest
+           runs_on: depot-ubuntu-24.04-4
          - service_name: "build-web-arm64"
            image_name_env: "DIFY_WEB_IMAGE_NAME"
            artifact_context: "web"
            build_context: "{{defaultContext}}"
            file: "web/Dockerfile"
            platform: linux/arm64
-           runs_on: ubuntu-24.04-arm
+           runs_on: depot-ubuntu-24.04-4

     steps:
       - name: Prepare
@@ -70,8 +73,8 @@ jobs:
           username: ${{ env.DOCKERHUB_USER }}
           password: ${{ env.DOCKERHUB_TOKEN }}

-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
+      - name: Set up Depot CLI
+        uses: depot/setup-action@v1

       - name: Extract metadata for Docker
         id: meta
@@ -81,16 +84,15 @@ jobs:

       - name: Build Docker image
         id: build
-        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
+        uses: depot/build-push-action@v1
         with:
+          project: ${{ vars.DEPOT_PROJECT_ID }}
           context: ${{ matrix.build_context }}
           file: ${{ matrix.file }}
           platforms: ${{ matrix.platform }}
           build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
           labels: ${{ steps.meta.outputs.labels }}
           outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
-          cache-from: type=gha,scope=${{ matrix.service_name }}
-          cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}

       - name: Export digest
         env:
@@ -108,9 +110,33 @@ jobs:
           if-no-files-found: error
           retention-days: 1

+  fork-build-validate:
+    if: github.repository != 'langgenius/dify'
+    runs-on: ubuntu-24.04
+    strategy:
+      matrix:
+        include:
+          - service_name: "validate-api-amd64"
+            build_context: "{{defaultContext}}:api"
+            file: "Dockerfile"
+          - service_name: "validate-web-amd64"
+            build_context: "{{defaultContext}}"
+            file: "web/Dockerfile"
+    steps:
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
+
+      - name: Validate Docker image
+        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
+        with:
+          push: false
+          context: ${{ matrix.build_context }}
+          file: ${{ matrix.file }}
+          platforms: linux/amd64
+
   create-manifest:
     needs: build
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: github.repository == 'langgenius/dify'
     strategy:
       matrix:
.github/workflows/db-migration-test.yml (vendored, 4 changes)

@@ -9,7 +9,7 @@ concurrency:

 jobs:
   db-migration-test-postgres:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
@@ -59,7 +59,7 @@ jobs:
         run: uv run --directory api flask upgrade-db

   db-migration-test-mysql:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
.github/workflows/deploy-agent-dev.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/agent-dev'
.github/workflows/deploy-dev.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/dev'
.github/workflows/deploy-enterprise.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/enterprise'
.github/workflows/deploy-hitl.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'build/feat/hitl'
.github/workflows/docker-build.yml (vendored, 47 changes)

@@ -14,40 +14,69 @@ concurrency:

 jobs:
   build-docker:
+    if: github.event.pull_request.head.repo.full_name == github.repository
     runs-on: ${{ matrix.runs_on }}
+    permissions:
+      contents: read
+      id-token: write
     strategy:
       matrix:
         include:
          - service_name: "api-amd64"
            platform: linux/amd64
-           runs_on: ubuntu-latest
+           runs_on: depot-ubuntu-24.04-4
            context: "{{defaultContext}}:api"
            file: "Dockerfile"
          - service_name: "api-arm64"
            platform: linux/arm64
-           runs_on: ubuntu-24.04-arm
+           runs_on: depot-ubuntu-24.04-4
            context: "{{defaultContext}}:api"
            file: "Dockerfile"
          - service_name: "web-amd64"
            platform: linux/amd64
-           runs_on: ubuntu-latest
+           runs_on: depot-ubuntu-24.04-4
            context: "{{defaultContext}}"
            file: "web/Dockerfile"
          - service_name: "web-arm64"
            platform: linux/arm64
-           runs_on: ubuntu-24.04-arm
+           runs_on: depot-ubuntu-24.04-4
            context: "{{defaultContext}}"
            file: "web/Dockerfile"
     steps:
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
+      - name: Set up Depot CLI
+        uses: depot/setup-action@v1

       - name: Build Docker Image
-        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
+        uses: depot/build-push-action@v1
         with:
+          project: ${{ vars.DEPOT_PROJECT_ID }}
           push: false
           context: ${{ matrix.context }}
           file: ${{ matrix.file }}
           platforms: ${{ matrix.platform }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+
+  build-docker-fork:
+    if: github.event.pull_request.head.repo.full_name != github.repository
+    runs-on: ubuntu-24.04
+    permissions:
+      contents: read
+    strategy:
+      matrix:
+        include:
+          - service_name: "api-amd64"
+            context: "{{defaultContext}}:api"
+            file: "Dockerfile"
+          - service_name: "web-amd64"
+            context: "{{defaultContext}}"
+            file: "web/Dockerfile"
+    steps:
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
+
+      - name: Build Docker Image
+        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
+        with:
+          push: false
+          context: ${{ matrix.context }}
+          file: ${{ matrix.file }}
+          platforms: linux/amd64
.github/workflows/labeler.yml (vendored, 2 changes)

@@ -7,7 +7,7 @@ jobs:
     permissions:
       contents: read
       pull-requests: write
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
      - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
        with:
.github/workflows/main-ci.yml (vendored, 24 changes)

@@ -23,7 +23,7 @@ concurrency:
 jobs:
   pre_job:
     name: Skip Duplicate Checks
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     outputs:
       should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
     steps:
@@ -39,7 +39,7 @@ jobs:
     name: Check Changed Files
     needs: pre_job
     if: needs.pre_job.outputs.should_skip != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     outputs:
       api-changed: ${{ steps.changes.outputs.api }}
       e2e-changed: ${{ steps.changes.outputs.e2e }}
@@ -141,7 +141,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped API tests
         run: echo "No API-related changes detected; skipping API tests."
@@ -154,7 +154,7 @@ jobs:
       - check-changes
       - api-tests-run
       - api-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize API Tests status
         env:
@@ -201,7 +201,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped web tests
         run: echo "No web-related changes detected; skipping web tests."
@@ -214,7 +214,7 @@ jobs:
       - check-changes
      - web-tests-run
      - web-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize Web Tests status
         env:
@@ -260,7 +260,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped web full-stack e2e
         run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@@ -273,7 +273,7 @@ jobs:
       - check-changes
       - web-e2e-run
       - web-e2e-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize Web Full-Stack E2E status
         env:
@@ -325,7 +325,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped VDB tests
         run: echo "No VDB-related changes detected; skipping VDB tests."
@@ -338,7 +338,7 @@ jobs:
       - check-changes
       - vdb-tests-run
       - vdb-tests-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize VDB Tests status
         env:
@@ -384,7 +384,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Report skipped DB migration tests
         run: echo "No migration-related changes detected; skipping DB migration tests."
@@ -397,7 +397,7 @@ jobs:
       - check-changes
       - db-migration-test-run
       - db-migration-test-skip
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Finalize DB Migration Test status
         env:
.github/workflows/pyrefly-diff-comment.yml (vendored, 2 changes)

@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with pyrefly diff
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       actions: read
       contents: read
.github/workflows/pyrefly-diff.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-diff:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       contents: read
       issues: write
@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with type coverage
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       actions: read
       contents: read
.github/workflows/pyrefly-type-coverage.yml (vendored, 2 changes)

@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-type-coverage:
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       contents: read
       issues: write
.github/workflows/semantic-pull-request.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
     name: Validate PR title
     permissions:
       pull-requests: read
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
.github/workflows/stale.yml (vendored, 2 changes)

@@ -12,7 +12,7 @@ on:
 jobs:
   stale:

-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     permissions:
       issues: write
       pull-requests: write
.github/workflows/style.yml (vendored, 6 changes)

@@ -15,7 +15,7 @@ permissions:
 jobs:
   python-style:
     name: Python Style
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
@@ -57,7 +57,7 @@ jobs:

   web-style:
     name: Web Style
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     defaults:
       run:
         working-directory: ./web
@@ -131,7 +131,7 @@ jobs:

   superlinter:
     name: SuperLinter
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     steps:
       - name: Checkout code
.github/workflows/tool-test-sdks.yaml (vendored, 2 changes)

@@ -18,7 +18,7 @@ concurrency:
 jobs:
   build:
     name: unit test for Node.js SDK
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04

     defaults:
       run:
.github/workflows/translate-i18n-claude.yml (vendored, 4 changes)

@@ -35,7 +35,7 @@ concurrency:
 jobs:
   translate:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     timeout-minutes: 120

     steps:
@@ -158,7 +158,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.context.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
+        uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/trigger-i18n-sync.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   trigger:
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     timeout-minutes: 5

     steps:
.github/workflows/vdb-tests-full.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
   test:
     name: Full VDB Tests
     if: github.repository == 'langgenius/dify'
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     strategy:
       matrix:
         python-version:
.github/workflows/vdb-tests.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: VDB Smoke Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     strategy:
       matrix:
         python-version:
.github/workflows/web-e2e.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: Web Full-Stack E2E
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     defaults:
       run:
         shell: bash
.github/workflows/web-tests.yml (vendored, 6 changes)

@@ -16,7 +16,7 @@ concurrency:
 jobs:
   test:
     name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       VITEST_COVERAGE_SCOPE: app-components
     strategy:
@@ -54,7 +54,7 @@ jobs:
     name: Merge Test Reports
     if: ${{ !cancelled() }}
     needs: [test]
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
@@ -92,7 +92,7 @@ jobs:

   dify-ui-test:
     name: dify-ui Tests
-    runs-on: ubuntu-latest
+    runs-on: depot-ubuntu-24.04
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
@@ -147,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source

 ### Deployment with Kubernetes

-If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
+If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.

 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
 - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
@@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
 MARKETPLACE_ENABLED=true
 MARKETPLACE_API_URL=https://marketplace.dify.ai

+# Creators Platform configuration
+CREATORS_PLATFORM_FEATURES_ENABLED=true
+CREATORS_PLATFORM_API_URL=https://creators.dify.ai
+CREATORS_PLATFORM_OAUTH_CLIENT_ID=
+
 # Endpoint configuration
 ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
@@ -11,7 +11,7 @@ from configs import dify_config
 from core.helper import encrypter
 from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.plugin import PluginInstaller
-from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
+from core.tools.utils.system_encryption import encrypt_system_params
 from extensions.ext_database import db
 from models import Tenant
 from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):

         click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
         click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+        oauth_client_params = encrypt_system_params(client_params_dict)
         click.echo(click.style("Client params encrypted successfully.", fg="green"))
     except Exception as e:
         click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):

         click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
         click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+        oauth_client_params = encrypt_system_params(client_params_dict)
         click.echo(click.style("Client params encrypted successfully.", fg="green"))
     except Exception as e:
         click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
     )


+class CreatorsPlatformConfig(BaseSettings):
+    """
+    Configuration for Creators Platform integration
+    """
+
+    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
+        description="Enable or disable Creators Platform features",
+        default=True,
+    )
+
+    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
+        description="Creators Platform API URL",
+        default=HttpUrl("https://creators.dify.ai"),
+    )
+
+    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
+        description="OAuth client ID for Creators Platform integration",
+        default="",
+    )
+
+
 class EndpointConfig(BaseSettings):
     """
     Configuration for various application endpoints and URLs
@@ -1379,6 +1400,7 @@ class FeatureConfig(
     AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
+    CreatorsPlatformConfig,
     TriggerConfig,
     AsyncWorkflowConfig,
     PluginConfig,
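For context, the CREATORS_PLATFORM_* variables added to the env example above feed this new settings class directly. A minimal sketch of that behavior, assuming pydantic-settings (which Dify's config classes are built on) and its default field-name-to-environment-variable lookup:

import os

from pydantic import Field, HttpUrl
from pydantic_settings import BaseSettings


class CreatorsPlatformConfig(BaseSettings):
    # Under pydantic-settings, field names double as environment variable names.
    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(default=True)
    CREATORS_PLATFORM_API_URL: HttpUrl = Field(default=HttpUrl("https://creators.dify.ai"))
    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(default="")


os.environ["CREATORS_PLATFORM_FEATURES_ENABLED"] = "false"
config = CreatorsPlatformConfig()
assert config.CREATORS_PLATFORM_FEATURES_ENABLED is False  # env value overrides the default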
api/controllers/common/human_input.py (new file, 6 lines)

@@ -0,0 +1,6 @@
+from pydantic import BaseModel, JsonValue
+
+
+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict[str, JsonValue]
+    action: str
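A quick usage sketch of this shared payload model (standard pydantic v2 validation; the input values here are made up):

from controllers.common.human_input import HumanInputFormSubmitPayload

# `inputs` accepts any JSON-serializable values thanks to JsonValue.
payload = HumanInputFormSubmitPayload.model_validate(
    {"inputs": {"comment": "ship it", "priority": 2}, "action": "approve"}
)
assert payload.action == "approve"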
@@ -692,6 +692,32 @@ class AppExportApi(Resource):
         return payload.model_dump(mode="json")


+@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
+class AppPublishToCreatorsPlatformApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=None)
+    @edit_permission_required
+    def post(self, app_model):
+        """Publish app to Creators Platform"""
+        from configs import dify_config
+        from core.helper.creators import get_redirect_url, upload_dsl
+
+        if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
+            return {"error": "Creators Platform features are not enabled"}, 403
+
+        current_user, _ = current_account_with_tenant()
+
+        dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
+        dsl_bytes = dsl_content.encode("utf-8")
+
+        claim_code = upload_dsl(dsl_bytes)
+        redirect_url = get_redirect_url(str(current_user.id), claim_code)
+
+        return {"redirect_url": redirect_url}
+
+
 @console_ns.route("/apps/<uuid:app_id>/name")
 class AppNameApi(Resource):
     @console_ns.doc("check_app_name")
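From a client's perspective the new route is a single POST that returns a redirect URL. A hypothetical sketch (the base URL, app id, and token are placeholders; the `/console/api` prefix is an assumption not confirmed by this diff):

import requests

CONSOLE_BASE = "http://localhost:5001/console/api"            # assumed console API prefix
APP_ID = "00000000-0000-0000-0000-000000000000"               # placeholder app id
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # placeholder session token

resp = requests.post(f"{CONSOLE_BASE}/apps/{APP_ID}/publish-to-creators-platform", headers=HEADERS)
if resp.status_code == 403:
    print("Creators Platform features are disabled on this deployment")
else:
    # The handler uploads the exported DSL and returns where to finish claiming it.
    print("Continue in browser:", resp.json()["redirect_url"])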
@@ -8,10 +8,10 @@ from collections.abc import Generator

 from flask import Response, jsonify, request
 from flask_restx import Resource
-from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

+from controllers.common.human_input import HumanInputFormSubmitPayload
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
@@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.apps.message_generator import MessageGenerator
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
+from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
 from models import App
 from models.enums import CreatorUserRole
-from models.human_input import RecipientType
 from models.model import AppMode
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
@@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 logger = logging.getLogger(__name__)


-class HumanInputFormSubmitPayload(BaseModel):
-    inputs: dict
-    action: str
-
-
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
         if form.tenant_id != current_tenant_id:
             raise NotFoundError("App not found")

+    @staticmethod
+    def _ensure_console_recipient_type(form: Form) -> None:
+        if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
+            raise NotFoundError("form not found")
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
             raise NotFoundError(f"form not found, token={form_token}")

         self._ensure_console_access(form)
+        self._ensure_console_recipient_type(form)
         recipient_type = form.recipient_type
-        if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
-            raise NotFoundError(f"form not found, token={form_token}")
         # The type checker is not smart enought to validate the following invariant.
         # So we need to assert it manually.
         assert recipient_type is not None, "recipient_type cannot be None here."
@@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
     type: TagType = Field(description="Tag type")


+class TagBindingItemDeletePayload(BaseModel):
+    target_id: str = Field(description="Target ID to unbind tag from")
+    type: TagType = Field(description="Tag type")
+
+
 class TagListQueryParam(BaseModel):
     type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
     keyword: str | None = Field(None, description="Search keyword")
@@ -70,6 +75,7 @@ register_schema_models(
     TagBasePayload,
     TagBindingPayload,
     TagBindingRemovePayload,
+    TagBindingItemDeletePayload,
     TagListQueryParam,
     TagResponse,
 )
@@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
         return "", 204


-@console_ns.route("/tag-bindings/create")
-class TagBindingCreateApi(Resource):
+def _require_tag_binding_edit_permission() -> None:
+    """
+    Ensure the current account can edit tag bindings.
+
+    Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
+    """
+    current_user, _ = current_account_with_tenant()
+    # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+        raise Forbidden()
+
+
+def _create_tag_bindings() -> tuple[dict[str, str], int]:
+    _require_tag_binding_edit_permission()
+
+    payload = TagBindingPayload.model_validate(console_ns.payload or {})
+    TagService.save_tag_binding(
+        TagBindingCreatePayload(
+            tag_ids=payload.tag_ids,
+            target_id=payload.target_id,
+            type=payload.type,
+        )
+    )
+    return {"result": "success"}, 200
+
+
+def _remove_tag_binding() -> tuple[dict[str, str], int]:
+    _require_tag_binding_edit_permission()
+
+    payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
+    TagService.delete_tag_binding(
+        TagBindingDeletePayload(
+            tag_id=payload.tag_id,
+            target_id=payload.target_id,
+            type=payload.type,
+        )
+    )
+    return {"result": "success"}, 200
+
+
+@console_ns.route("/tag-bindings")
+class TagBindingCollectionApi(Resource):
+    """Canonical collection resource for tag binding creation."""
+
+    @console_ns.doc("create_tag_binding")
     @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
-
-        payload = TagBindingPayload.model_validate(console_ns.payload or {})
-        TagService.save_tag_binding(
-            TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
-        )
-
-        return {"result": "success"}, 200
+        return _create_tag_bindings()
+
+
+@console_ns.route("/tag-bindings/<uuid:id>")
+class TagBindingItemApi(Resource):
+    """Canonical item resource for tag binding deletion."""
+
+    @console_ns.doc("delete_tag_binding")
+    @console_ns.doc(params={"id": "Tag ID"})
+    @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, id):
+        _require_tag_binding_edit_permission()
+        payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
+        TagService.delete_tag_binding(
+            TagBindingDeletePayload(
+                tag_id=str(id),
+                target_id=payload.target_id,
+                type=payload.type,
+            )
+        )
+        return {"result": "success"}, 200
+
+
+@console_ns.route("/tag-bindings/create")
+class DeprecatedTagBindingCreateApi(Resource):
+    """Deprecated verb-based alias for tag binding creation."""
+
+    @console_ns.doc("create_tag_binding_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
+    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        return _create_tag_bindings()


 @console_ns.route("/tag-bindings/remove")
-class TagBindingDeleteApi(Resource):
+class DeprecatedTagBindingRemoveApi(Resource):
+    """Deprecated verb-based alias for tag binding deletion."""
+
+    @console_ns.doc("delete_tag_binding_deprecated")
+    @console_ns.doc(deprecated=True)
+    @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
     @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
-        current_user, _ = current_account_with_tenant()
-        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
-        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
-            raise Forbidden()
-
-        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
-        TagService.delete_tag_binding(
-            TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
-        )
-
-        return {"result": "success"}, 200
+        return _remove_tag_binding()
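The net effect of this refactor is a move from verb-based routes to resource-style ones, with the old paths kept as deprecated aliases. A hypothetical client-side sketch (base URL, ids, and auth token are placeholders):

import requests

BASE = "http://localhost:5001/console/api"                    # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # placeholder session token
BODY = {"target_id": "<app-or-dataset-id>", "type": "app"}

# Canonical create: POST /tag-bindings (replaces POST /tag-bindings/create)
requests.post(f"{BASE}/tag-bindings", json={**BODY, "tag_ids": ["<tag-id>"]}, headers=HEADERS)

# Canonical delete: DELETE /tag-bindings/<tag_id> (replaces POST /tag-bindings/remove)
requests.delete(f"{BASE}/tag-bindings/<tag-id>", json=BODY, headers=HEADERS)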
@@ -23,9 +23,11 @@ from .app import (
     conversation,
     file,
     file_preview,
+    human_input_form,
     message,
     site,
     workflow,
+    workflow_events,
 )
 from .dataset import (
     dataset,
@@ -50,6 +52,7 @@ __all__ = [
     "file",
     "file_preview",
     "hit_testing",
+    "human_input_form",
     "index",
     "message",
     "metadata",
@@ -58,6 +61,7 @@ __all__ = [
     "segment",
     "site",
     "workflow",
+    "workflow_events",
 ]

 api.add_namespace(service_api_ns)
api/controllers/service_api/app/human_input_form.py (new file, 137 lines)

@@ -0,0 +1,137 @@
+"""
+Service API human input form endpoints.
+
+This module exposes app-token authenticated APIs for fetching and submitting
+paused human input forms in workflow/chatflow runs.
+"""
+
+import json
+import logging
+from datetime import datetime
+
+from flask import Response
+from flask_restx import Resource
+from werkzeug.exceptions import BadRequest, NotFound
+
+from controllers.common.human_input import HumanInputFormSubmitPayload
+from controllers.common.schema import register_schema_models
+from controllers.service_api import service_api_ns
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
+from extensions.ext_database import db
+from models.model import App, EndUser
+from services.human_input_service import Form, FormNotFoundError, HumanInputService
+
+logger = logging.getLogger(__name__)
+
+
+register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
+
+
+def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
+    result: dict[str, str] = {}
+    for key, value in values.items():
+        if value is None:
+            result[key] = ""
+        elif isinstance(value, (dict, list)):
+            result[key] = json.dumps(value, ensure_ascii=False)
+        else:
+            result[key] = str(value)
+    return result
+
+
+def _to_timestamp(value: datetime) -> int:
+    return int(value.timestamp())
+
+
+def _jsonify_form_definition(form: Form) -> Response:
+    definition_payload = form.get_definition().model_dump()
+    payload = {
+        "form_content": definition_payload["rendered_content"],
+        "inputs": definition_payload["inputs"],
+        "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
+        "user_actions": definition_payload["user_actions"],
+        "expiration_time": _to_timestamp(form.expiration_time),
+    }
+    return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
+
+
+def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
+    if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
+        raise NotFound("Form not found")
+
+
+def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
+    # Keep app-token callers scoped to the public web-form surface; internal HITL
+    # routes must continue to flow through console-only authentication.
+    if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
+        raise NotFound("Form not found")
+
+
+@service_api_ns.route("/form/human_input/<string:form_token>")
+class WorkflowHumanInputFormApi(Resource):
+    @service_api_ns.doc("get_human_input_form")
+    @service_api_ns.doc(description="Get a paused human input form by token")
+    @service_api_ns.doc(params={"form_token": "Human input form token"})
+    @service_api_ns.doc(
+        responses={
+            200: "Form retrieved successfully",
+            401: "Unauthorized - invalid API token",
+            404: "Form not found",
+            412: "Form already submitted or expired",
+        }
+    )
+    @validate_app_token
+    def get(self, app_model: App, form_token: str):
+        service = HumanInputService(db.engine)
+        form = service.get_form_by_token(form_token)
+        if form is None:
+            raise NotFound("Form not found")
+
+        _ensure_form_belongs_to_app(form, app_model)
+        _ensure_form_is_allowed_for_service_api(form)
+        service.ensure_form_active(form)
+        return _jsonify_form_definition(form)
+
+    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
+    @service_api_ns.doc("submit_human_input_form")
+    @service_api_ns.doc(description="Submit a paused human input form by token")
+    @service_api_ns.doc(params={"form_token": "Human input form token"})
+    @service_api_ns.doc(
+        responses={
+            200: "Form submitted successfully",
+            400: "Bad request - invalid submission data",
+            401: "Unauthorized - invalid API token",
+            404: "Form not found",
+            412: "Form already submitted or expired",
+        }
+    )
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, form_token: str):
+        payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
+
+        service = HumanInputService(db.engine)
+        form = service.get_form_by_token(form_token)
+        if form is None:
+            raise NotFound("Form not found")
+
+        _ensure_form_belongs_to_app(form, app_model)
+        _ensure_form_is_allowed_for_service_api(form)
+
+        recipient_type = form.recipient_type
+        if recipient_type is None:
+            logger.warning("Recipient type is None for form, form_id=%s", form.id)
+            raise BadRequest("Form recipient type is invalid")
+
+        try:
+            service.submit_form_by_token(
+                recipient_type=recipient_type,
+                form_token=form_token,
+                selected_action_id=payload.action,
+                form_data=payload.inputs,
+                submission_end_user_id=end_user.id,
+            )
+        except FormNotFoundError:
+            raise NotFound("Form not found")
+
+        return {}, 200
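A hypothetical end-to-end sketch of this new Service API surface (the paths come from the routes above; the base URL, app token, and form token are placeholders, and the `/v1` prefix is an assumption not confirmed by this diff):

import requests

BASE = "http://localhost:5001/v1"                      # assumed Service API prefix
HEADERS = {"Authorization": "Bearer <app-api-token>"}  # placeholder app token
FORM_TOKEN = "<token-from-a-workflow_paused-event>"    # placeholder

# Fetch the paused form definition (404 if the token maps to another app,
# 412 if the form was already submitted or has expired).
form = requests.get(f"{BASE}/form/human_input/{FORM_TOKEN}", headers=HEADERS).json()
print(form["form_content"], form["user_actions"], form["expiration_time"])

# Submit it; `user` is required in the JSON body (WhereisUserArg.JSON above).
requests.post(
    f"{BASE}/form/human_input/{FORM_TOKEN}",
    headers=HEADERS,
    json={"user": "end-user-123", "inputs": {"comment": "approved"}, "action": "approve"},
)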
142
api/controllers/service_api/app/workflow_events.py
Normal file
142
api/controllers/service_api/app/workflow_events.py
Normal file
@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""

import json
from collections.abc import Generator

from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound

from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream


@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
    """Service API for getting workflow execution events after resume."""

    @service_api_ns.doc("get_workflow_events")
    @service_api_ns.doc(description="Get workflow execution events stream after resume")
    @service_api_ns.doc(
        params={
            "task_id": "Workflow run ID",
            "user": "End user identifier (query param)",
            "include_state_snapshot": (
                "Whether to replay from persisted state snapshot, "
                'specify `"true"` to include a status snapshot of executed nodes'
            ),
            "continue_on_pause": (
                "Whether to keep the stream open across workflow_paused events, "
                'specify `"true"` to keep the stream open for `workflow_paused` events.'
            ),
        }
    )
    @service_api_ns.doc(
        responses={
            200: "SSE event stream",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, task_id: str):
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            raise NotWorkflowAppError()

        session_maker = sessionmaker(db.engine)
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
            tenant_id=app_model.tenant_id,
            run_id=task_id,
        )

        if workflow_run is None:
            raise NotFound("Workflow run not found")

        if workflow_run.app_id != app_model.id:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by_role != CreatorUserRole.END_USER:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by != end_user.id:
            raise NotFound("Workflow run not found")

        workflow_run_entity = workflow_run

        if workflow_run_entity.finished_at is not None:
            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
                task_id=workflow_run_entity.id,
                workflow_run=workflow_run_entity,
                creator_user=end_user,
            )

            payload = response.model_dump(mode="json")
            payload["event"] = response.event.value

            def _generate_finished_events() -> Generator[str, None, None]:
                yield f"data: {json.dumps(payload)}\n\n"

            event_generator = _generate_finished_events
        else:
            msg_generator = MessageGenerator()
            generator: BaseAppGenerator
            if app_mode == AppMode.ADVANCED_CHAT:
                generator = AdvancedChatAppGenerator()
            elif app_mode == AppMode.WORKFLOW:
                generator = WorkflowAppGenerator()
            else:
                raise NotWorkflowAppError()

            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
            continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
            terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None

            def _generate_stream_events():
                if include_state_snapshot:
                    return generator.convert_to_event_stream(
                        build_workflow_event_stream(
                            app_mode=app_mode,
                            workflow_run=workflow_run_entity,
                            tenant_id=app_model.tenant_id,
                            app_id=app_model.id,
                            session_maker=session_maker,
                            human_input_surface=HumanInputSurface.SERVICE_API,
                            close_on_pause=not continue_on_pause,
                        )
                    )
                return generator.convert_to_event_stream(
                    msg_generator.retrieve_events(
                        app_mode,
                        workflow_run_entity.id,
                        terminal_events=terminal_events,
                    ),
                )

            event_generator = _generate_stream_events

        return Response(
            event_generator(),
            mimetype="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
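The endpoint above emits standard `text/event-stream` frames. A minimal consumer sketch, with the base URL and token as placeholders (`requests` is used here only for illustration, not mandated by the diff):

    import json
    import requests

    API_BASE = "https://api.example.com/v1"  # placeholder base URL
    task_id = "workflow-run-id"              # the workflow run to reattach to
    params = {"user": "end-user-123", "include_state_snapshot": "true", "continue_on_pause": "true"}
    headers = {"Authorization": "Bearer app-xxx"}  # placeholder Service API token

    with requests.get(f"{API_BASE}/workflow/{task_id}/events", params=params, headers=headers, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip keep-alive pings and blank frame separators
            event = json.loads(line[len("data: "):])
            print(event.get("event"))  # e.g. workflow_paused, workflow_finished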
@@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict

 from flask import Response, request
 from flask_restx import Resource
-from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
+from controllers.common.human_input import HumanInputFormSubmitPayload
 from controllers.web import web_ns
 from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
 from controllers.web.site import serialize_app_site_payload
@@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ

 logger = logging.getLogger(__name__)


-class HumanInputFormSubmitPayload(BaseModel):
-    inputs: dict
-    action: str
-
-
 _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
     prefix="web_form_submit_rate_limit",
     max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
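The per-blueprint Pydantic model is dropped in favor of a single shared definition. Judging by the fields removed here and the Service API usage earlier in this diff, the shared model in `controllers.common.human_input` presumably looks like this minimal sketch:

    from pydantic import BaseModel


    class HumanInputFormSubmitPayload(BaseModel):
        # Presumed shared definition; mirrors the two fields removed above.
        inputs: dict  # field values keyed by form input name
        action: str   # id of the selected UserAction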
@@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
-from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
+    ChatbotAppBlockingResponse,
+    ChatbotAppStreamResponse,
+)
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
+    ) -> (
+        ChatbotAppBlockingResponse
+        | AdvancedChatPausedBlockingResponse
+        | Generator[ChatbotAppStreamResponse, None, None]
+    ):
         """
         Handle response.
         :param application_generate_entity: application generate entity
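Blocking callers of the generator now receive a three-way union instead of response-or-stream. A sketch of how a caller might discriminate the result; the types are the ones introduced in this commit, while the handler bodies are placeholders:

    from collections.abc import Generator

    from core.app.entities.task_entities import AdvancedChatPausedBlockingResponse


    def handle_result(result) -> None:
        if isinstance(result, Generator):
            for chunk in result:  # streaming mode: ChatbotAppStreamResponse items
                print(chunk)
        elif isinstance(result, AdvancedChatPausedBlockingResponse):
            # The run paused for human input; details are on result.data.
            print(result.data.paused_nodes, result.data.reasons)
        else:
            # Plain ChatbotAppBlockingResponse: the run completed.
            print(result.data.answer)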
@@ -3,7 +3,7 @@ from typing import Any, cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
-    AppBlockingResponse,
+    AdvancedChatPausedBlockingResponse,
     AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
@@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
     PingStreamResponse,
+    StreamEvent,
 )


-class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
+class AdvancedChatAppGenerateResponseConverter(
+    AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
+):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_full_response(
+        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
+        if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
+            paused_data = blocking_response.data.model_dump(mode="json")
+            return {
+                "event": StreamEvent.WORKFLOW_PAUSED.value,
+                "task_id": blocking_response.task_id,
+                "id": blocking_response.data.id,
+                "message_id": blocking_response.data.message_id,
+                "conversation_id": blocking_response.data.conversation_id,
+                "mode": blocking_response.data.mode,
+                "answer": blocking_response.data.answer,
+                "metadata": blocking_response.data.metadata,
+                "created_at": blocking_response.data.created_at,
+                "workflow_run_id": blocking_response.data.workflow_run_id,
+                "data": paused_data,
+            }
+
         response = {
-            "event": "message",
+            "event": StreamEvent.MESSAGE.value,
             "task_id": blocking_response.task_id,
             "id": blocking_response.data.id,
             "message_id": blocking_response.data.message_id,
@@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_simple_response(
+        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)

         return response
@@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
     WorkflowQueueMessage,
 )
 from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
+    HumanInputRequiredPauseReasonPayload,
+    HumanInputRequiredResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     MessageEndStreamResponse,
     PingStreamResponse,
     StreamResponse,
+    WorkflowPauseStreamResponse,
     WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         if message.status == MessageStatus.PAUSED and message.answer:
             self._task_state.answer = message.answer

-    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def process(
+        self,
+    ) -> Union[
+        ChatbotAppBlockingResponse,
+        AdvancedChatPausedBlockingResponse,
+        Generator[ChatbotAppStreamResponse, None, None],
+    ]:
         """
         Process generate task pipeline.
         :return:
@@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         else:
             return self._to_blocking_response(generator)

-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
+    def _to_blocking_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
         """
         Process blocking response.
         :return:
         """
+        human_input_responses: list[HumanInputRequiredResponse] = []
         for stream_response in generator:
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
+            elif isinstance(stream_response, HumanInputRequiredResponse):
+                human_input_responses.append(stream_response)
+            elif isinstance(stream_response, WorkflowPauseStreamResponse):
+                return AdvancedChatPausedBlockingResponse(
+                    task_id=stream_response.task_id,
+                    data=AdvancedChatPausedBlockingResponse.Data(
+                        id=self._message_id,
+                        mode=self._conversation_mode,
+                        conversation_id=self._conversation_id,
+                        message_id=self._message_id,
+                        workflow_run_id=stream_response.data.workflow_run_id,
+                        answer=self._task_state.answer,
+                        metadata=self._message_end_to_stream_response().metadata,
+                        created_at=self._message_created_at,
+                        paused_nodes=stream_response.data.paused_nodes,
+                        reasons=stream_response.data.reasons,
+                        status=stream_response.data.status,
+                        elapsed_time=stream_response.data.elapsed_time,
+                        total_tokens=stream_response.data.total_tokens,
+                        total_steps=stream_response.data.total_steps,
+                    ),
+                )
             elif isinstance(stream_response, MessageEndStreamResponse):
                 extras = {}
                 if stream_response.metadata:
@@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
             else:
                 continue

+        if human_input_responses:
+            return self._build_paused_blocking_response_from_human_input(human_input_responses)
+
         raise ValueError("queue listening stopped unexpectedly.")

+    def _build_paused_blocking_response_from_human_input(
+        self, human_input_responses: list[HumanInputRequiredResponse]
+    ) -> AdvancedChatPausedBlockingResponse:
+        runtime_state = self._resolve_graph_runtime_state()
+        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
+        reasons = [
+            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
+            for response in human_input_responses
+        ]
+
+        return AdvancedChatPausedBlockingResponse(
+            task_id=self._application_generate_entity.task_id,
+            data=AdvancedChatPausedBlockingResponse.Data(
+                id=self._message_id,
+                mode=self._conversation_mode,
+                conversation_id=self._conversation_id,
+                message_id=self._message_id,
+                workflow_run_id=human_input_responses[-1].workflow_run_id,
+                answer=self._task_state.answer,
+                metadata=self._message_end_to_stream_response().metadata,
+                created_at=self._message_created_at,
+                paused_nodes=paused_nodes,
+                reasons=reasons,
+                status=WorkflowExecutionStatus.PAUSED,
+                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
+                total_tokens=runtime_state.total_tokens,
+                total_steps=runtime_state.node_run_steps,
+            ),
+        )
+
     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
     ) -> Generator[ChatbotAppStreamResponse, Any, None]:
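`_build_paused_blocking_response_from_human_input` dedupes node ids with `dict.fromkeys`, which, unlike a `set`, preserves the order in which nodes first paused. A standalone illustration of the idiom:

    paused = ["node_a", "node_b", "node_a", "node_c", "node_b"]
    unique_in_order = list(dict.fromkeys(paused))  # dict keys are unique and insertion-ordered
    assert unique_in_order == ["node_a", "node_b", "node_c"]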
@@ -1,6 +1,8 @@
 from collections.abc import Generator
 from typing import Any, cast

+from pydantic import JsonValue
+
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
     AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
 )


-class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
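The chunks are now annotated as `dict[str, JsonValue]`, pydantic's recursive alias for anything JSON-serializable. A small sketch of what the annotation admits and rejects; this is pure illustration, not code from the diff:

    from pydantic import JsonValue, TypeAdapter

    adapter = TypeAdapter(dict[str, JsonValue])
    adapter.validate_python({"event": "message", "answer": ["a", 1, None]})  # ok: JSON-shaped

    try:
        adapter.validate_python({"created_at": object()})  # not JSON-serializable
    except Exception as exc:
        print(type(exc).__name__)  # ValidationError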
@@ -1,7 +1,9 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import Any, Union, cast
+
+from pydantic import JsonValue

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
 logger = logging.getLogger(__name__)


-class AppGenerateResponseConverter(ABC):
-    _blocking_response_type: type[AppBlockingResponse]
+class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
+    @classmethod
+    def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
+        return cast(TBlockingResponse, response)

     @classmethod
     def convert(
@@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
-                return cls.convert_blocking_full_response(response)
+                return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
             else:

                 def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
                 return _generate_full_response()
         else:
             if isinstance(response, AppBlockingResponse):
-                return cls.convert_blocking_simple_response(response)
+                return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
             else:

                 def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):

     @classmethod
     @abstractmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError

     @classmethod
     @abstractmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError

     @classmethod
@@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
         return metadata

     @classmethod
-    def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
+    def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
         """
         Error to stream response.
         :param e: exception
         :return:
         """
-        error_responses: dict[type[Exception], dict[str, Any]] = {
+        error_responses: dict[type[Exception], dict[str, JsonValue]] = {
             ValueError: {"code": "invalid_param", "status": 400},
             ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
             QuotaExceededError: {
@@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
         }

         # Determine the response based on the type of exception
-        data: dict[str, Any] | None = None
+        data: dict[str, JsonValue] | None = None
         for k, v in error_responses.items():
             if isinstance(e, k):
                 data = v
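The base converter becomes generic over its blocking-response type using PEP 695 syntax (`class C[T: Bound]`), so each subclass pins the concrete payload its abstract methods receive, and `_cast_blocking_response` centralizes the one remaining `cast`. A self-contained sketch of the pattern on Python 3.12+, with toy names rather than the production classes:

    from abc import ABC, abstractmethod
    from typing import cast


    class Response:  # stand-in for AppBlockingResponse
        pass


    class ChatResponse(Response):
        answer: str = "hi"


    class Converter[TResponse: Response](ABC):
        @classmethod
        def _cast(cls, response: Response) -> TResponse:
            # Single narrowing point: subclasses receive the precise type.
            return cast(TResponse, response)

        @classmethod
        @abstractmethod
        def convert(cls, response: TResponse) -> dict: ...


    class ChatConverter(Converter[ChatResponse]):
        @classmethod
        def convert(cls, response: ChatResponse) -> dict:
            return {"answer": response.answer}  # no "type: ignore[override]" needed


    print(ChatConverter.convert(ChatConverter._cast(ChatResponse())))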
@@ -1,6 +1,8 @@
 from collections.abc import Generator
 from typing import Any, cast

+from pydantic import JsonValue
+
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
     AppStreamResponse,
@@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
 )


-class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,

@@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
 from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
 from core.trigger.trigger_manager import TriggerManager
 from core.workflow.human_input_forms import load_form_tokens_by_form_id
+from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
 from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
@@ -336,7 +337,26 @@ class WorkflowResponseConverter:
             except (TypeError, json.JSONDecodeError):
                 definition_payload = {}
             display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
-        form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
+        form_token_by_form_id = load_form_tokens_by_form_id(
+            human_input_form_ids,
+            session=session,
+            surface=(
+                HumanInputSurface.SERVICE_API
+                if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
+                else None
+            ),
+        )
+
+        # Reconnect paths must preserve the same pause-reason contract as live streams;
+        # otherwise clients see schema drift after resume.
+        pause_reasons = enrich_human_input_pause_reasons(
+            pause_reasons,
+            form_tokens_by_form_id=form_token_by_form_id,
+            expiration_times_by_form_id={
+                form_id: int(expiration_time.timestamp())
+                for form_id, expiration_time in expiration_times_by_form_id.items()
+            },
+        )

         responses: list[StreamResponse] = []

@@ -1,6 +1,8 @@
 from collections.abc import Generator
 from typing import Any, cast

+from pydantic import JsonValue
+
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
     AppStreamResponse,
@@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
 )


-class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = CompletionAppBlockingResponse
-
+class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        response = {
+        response: dict[str, Any] = {
             "event": "message",
             "task_id": blocking_response.task_id,
             "id": blocking_response.data.id,
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "message_id": chunk.message_id,
                 "created_at": chunk.created_at,
@@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, JsonValue] = {
                 "event": sub_stream_response.event.value,
                 "message_id": chunk.message_id,
                 "created_at": chunk.created_at,

@@ -1,6 +1,7 @@
-from collections.abc import Callable, Generator, Mapping
+from collections.abc import Callable, Generator, Iterable, Mapping

 from core.app.apps.streaming_utils import stream_topic_events
+from core.app.entities.task_entities import StreamEvent
 from extensions.ext_redis import get_pubsub_broadcast_channel
 from libs.broadcast_channel.channel import Topic
 from models.model import AppMode
@@ -26,6 +27,7 @@ class MessageGenerator:
         idle_timeout=300,
         ping_interval: float = 10.0,
         on_subscribe: Callable[[], None] | None = None,
+        terminal_events: Iterable[str | StreamEvent] | None = None,
     ) -> Generator[Mapping | str, None, None]:
         topic = cls.get_response_topic(app_mode, workflow_run_id)
         return stream_topic_events(
@@ -33,4 +35,5 @@ class MessageGenerator:
             idle_timeout=idle_timeout,
             ping_interval=ping_interval,
             on_subscribe=on_subscribe,
+            terminal_events=terminal_events,
         )
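`retrieve_events` now forwards `terminal_events` so callers decide which stream events end the subscription. A sketch of the two call shapes used by the new Service API endpoint; the run id is a placeholder and the surrounding app objects are elided:

    from core.app.apps.message_generator import MessageGenerator
    from models.model import AppMode

    msg_generator = MessageGenerator()

    # Default behaviour (None): the stream closes at the first terminal event.
    events = msg_generator.retrieve_events(AppMode.WORKFLOW, "run-id", terminal_events=None)

    # continue_on_pause: an explicit empty list declares no terminal events, so the
    # stream stays open across workflow_paused and ends only on idle timeout.
    events_keep_open = msg_generator.retrieve_events(AppMode.WORKFLOW, "run-id", terminal_events=[])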
@@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
 )


-class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = WorkflowAppBlockingResponse
-
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
@@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         return dict(blocking_response.model_dump())

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
-from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.app.entities.task_entities import (
+    WorkflowAppBlockingResponse,
+    WorkflowAppPausedBlockingResponse,
+    WorkflowAppStreamResponse,
+)
 from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     OnlineDriveBrowseFilesRequest,
@@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
+    ) -> (
+        WorkflowAppBlockingResponse
+        | WorkflowAppPausedBlockingResponse
+        | Generator[WorkflowAppStreamResponse, None, None]
+    ):
         """
         Handle response.
         :param application_generate_entity: application generate entity

@@ -59,7 +59,7 @@ def stream_topic_events(


 def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
-    if not terminal_events:
+    if terminal_events is None:
         return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
     values: set[str] = set()
     for item in terminal_events:

@@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
 from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
-from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.app.entities.task_entities import (
+    WorkflowAppBlockingResponse,
+    WorkflowAppPausedBlockingResponse,
+    WorkflowAppStreamResponse,
+)
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.db.session_factory import session_factory
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
@@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
+    ) -> (
+        WorkflowAppBlockingResponse
+        | WorkflowAppPausedBlockingResponse
+        | Generator[WorkflowAppStreamResponse, None, None]
+    ):
         """
         Handle response.
         :param application_generate_entity: application generate entity

@@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
     NodeStartStreamResponse,
     PingStreamResponse,
     WorkflowAppBlockingResponse,
+    WorkflowAppPausedBlockingResponse,
     WorkflowAppStreamResponse,
 )


-class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = WorkflowAppBlockingResponse
+class WorkflowAppGenerateResponseConverter(
+    AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
+):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_full_response(
+        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        return blocking_response.model_dump()
+        return dict(blocking_response.model_dump())

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
+    def convert_blocking_simple_response(
+        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
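The guard change in `_normalize_terminal_events` is subtle: `if not terminal_events` treated an empty list the same as `None`, silently restoring the default terminal set, whereas the corrected `is None` check lets an explicit empty iterable mean "no event is terminal". A standalone illustration of the distinction:

    DEFAULTS = {"workflow_finished", "workflow_paused"}


    def normalize(terminal_events: list[str] | None) -> set[str]:
        # Mirrors the corrected logic: only None selects the defaults.
        if terminal_events is None:
            return set(DEFAULTS)
        return set(terminal_events)


    assert normalize(None) == DEFAULTS                          # caller wants the defaults
    assert normalize([]) == set()                               # caller wants a never-closing stream
    assert normalize(["workflow_finished"]) == {"workflow_finished"}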
@@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
 )
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
+    HumanInputRequiredPauseReasonPayload,
+    HumanInputRequiredResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     PingStreamResponse,
     StreamResponse,
     TextChunkStreamResponse,
     WorkflowAppBlockingResponse,
+    WorkflowAppPausedBlockingResponse,
     WorkflowAppStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowPauseStreamResponse,
@@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         )
         self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state

-    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
+    def process(
+        self,
+    ) -> Union[
+        WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
+    ]:
         """
         Process generate task pipeline.
         :return:
@@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         else:
             return self._to_blocking_response(generator)

-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
+    def _to_blocking_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
         """
         To blocking response.
         :return:
         """
+        human_input_responses: list[HumanInputRequiredResponse] = []
         for stream_response in generator:
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
+            elif isinstance(stream_response, HumanInputRequiredResponse):
+                human_input_responses.append(stream_response)
             elif isinstance(stream_response, WorkflowPauseStreamResponse):
-                response = WorkflowAppBlockingResponse(
+                return WorkflowAppPausedBlockingResponse(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run_id=stream_response.data.workflow_run_id,
-                    data=WorkflowAppBlockingResponse.Data(
+                    data=WorkflowAppPausedBlockingResponse.Data(
                         id=stream_response.data.workflow_run_id,
                         workflow_id=self._workflow.id,
                         status=stream_response.data.status,
@@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                         total_steps=stream_response.data.total_steps,
                         created_at=stream_response.data.created_at,
                         finished_at=None,
+                        paused_nodes=stream_response.data.paused_nodes,
+                        reasons=stream_response.data.reasons,
                     ),
                 )
-
-                return response
             elif isinstance(stream_response, WorkflowFinishStreamResponse):
-                response = WorkflowAppBlockingResponse(
+                return WorkflowAppBlockingResponse(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run_id=stream_response.data.id,
                     data=WorkflowAppBlockingResponse.Data(
@@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                     ),
                 )
-
-                return response
             else:
                 continue

+        if human_input_responses:
+            return self._build_paused_blocking_response_from_human_input(human_input_responses)
+
         raise ValueError("queue listening stopped unexpectedly.")

+    def _build_paused_blocking_response_from_human_input(
+        self, human_input_responses: list[HumanInputRequiredResponse]
+    ) -> WorkflowAppPausedBlockingResponse:
+        runtime_state = self._resolve_graph_runtime_state()
+        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
+        created_at = int(runtime_state.start_at)
+        reasons = [
+            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
+            for response in human_input_responses
+        ]
+
+        return WorkflowAppPausedBlockingResponse(
+            task_id=self._application_generate_entity.task_id,
+            workflow_run_id=human_input_responses[-1].workflow_run_id,
+            data=WorkflowAppPausedBlockingResponse.Data(
+                id=human_input_responses[-1].workflow_run_id,
+                workflow_id=self._workflow.id,
+                status=WorkflowExecutionStatus.PAUSED,
+                outputs={},
+                error=None,
+                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
+                total_tokens=runtime_state.total_tokens,
+                total_steps=runtime_state.node_run_steps,
+                created_at=created_at,
+                finished_at=None,
+                paused_nodes=paused_nodes,
+                reasons=reasons,
+            ),
+        )
+
     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
     ) -> Generator[WorkflowAppStreamResponse, None, None]:
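When no pause event arrives but human-input responses were collected, the pipeline above synthesizes the paused payload itself. A sketch of the JSON a blocking caller would see after conversion; field values and the serialized status string are illustrative, not taken from the diff:

    paused_blocking_payload = {
        "task_id": "task-uuid",
        "workflow_run_id": "run-uuid",
        "data": {
            "id": "run-uuid",
            "workflow_id": "workflow-uuid",
            "status": "paused",       # WorkflowExecutionStatus.PAUSED, serialization assumed
            "outputs": {},
            "error": None,
            "elapsed_time": 1.23,
            "total_tokens": 0,
            "total_steps": 4,
            "created_at": 1700000000,
            "finished_at": None,
            "paused_nodes": ["human_input_node"],
            "reasons": [
                # One HumanInputRequiredPauseReasonPayload per paused form (fields abridged).
                {"TYPE": "human_input_required", "form_id": "form-uuid", "node_id": "human_input_node"},
            ],
        },
    }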
@ -1,12 +1,13 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||||
|
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.rag.entities import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from graphon.entities import WorkflowStartReason
|
from graphon.entities import WorkflowStartReason
|
||||||
|
from graphon.entities.pause_reason import PauseReasonType
|
||||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||||
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
|
|||||||
data: Data
|
data: Data
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||||
|
"""
|
||||||
|
Public pause-reason payload used by blocking responses when only
|
||||||
|
``human_input_required`` events are available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||||
|
form_id: str
|
||||||
|
node_id: str
|
||||||
|
node_title: str
|
||||||
|
form_content: str
|
||||||
|
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||||
|
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||||
|
display_in_ui: bool = False
|
||||||
|
form_token: str | None = None
|
||||||
|
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
|
expiration_time: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
|
||||||
|
return cls(
|
||||||
|
form_id=data.form_id,
|
||||||
|
node_id=data.node_id,
|
||||||
|
node_title=data.node_title,
|
||||||
|
form_content=data.form_content,
|
||||||
|
inputs=data.inputs,
|
||||||
|
actions=data.actions,
|
||||||
|
display_in_ui=data.display_in_ui,
|
||||||
|
form_token=data.form_token,
|
||||||
|
resolved_default_values=data.resolved_default_values,
|
||||||
|
expiration_time=data.expiration_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormFilledResponse(StreamResponse):
|
class HumanInputFormFilledResponse(StreamResponse):
|
||||||
class Data(BaseModel):
|
class Data(BaseModel):
|
||||||
"""
|
"""
|
||||||
@@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

-    def to_ignore_detail_dict(self):
+    def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
         return {
             "event": self.event.value,
             "task_id": self.task_id,
@@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

-    def to_ignore_detail_dict(self):
+    def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
         return {
             "event": self.event.value,
             "task_id": self.task_id,
@@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
     data: Data


+class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
+    """
+    AdvancedChatPausedBlockingResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        mode: str
+        conversation_id: str
+        message_id: str
+        workflow_run_id: str
+        answer: str
+        metadata: Mapping[str, object] = Field(default_factory=dict)
+        created_at: int
+        paused_nodes: Sequence[str] = Field(default_factory=list)
+        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
+        status: WorkflowExecutionStatus
+        elapsed_time: float
+        total_tokens: int
+        total_steps: int
+
+    data: Data
+
+
 class CompletionAppBlockingResponse(AppBlockingResponse):
     """
     CompletionAppBlockingResponse entity
@@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
     data: Data


+class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
+    """
+    WorkflowAppPausedBlockingResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        workflow_id: str
+        status: WorkflowExecutionStatus
+        outputs: Mapping[str, Any] | None = None
+        error: str | None = None
+        elapsed_time: float
+        total_tokens: int
+        total_steps: int
+        created_at: int
+        finished_at: int | None
+        paused_nodes: Sequence[str] = Field(default_factory=list)
+        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
+
+    workflow_run_id: str
+    data: Data
+
+
 class AgentLogStreamResponse(StreamResponse):
     """
     AgentLogStreamResponse entity
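A construction sketch for the paused blocking shape; the same pattern applies to the advanced-chat variant above. The `task_id` field and the `PAUSED` status member are assumptions, not shown in this diff:

    paused = WorkflowAppPausedBlockingResponse(
        task_id="task-1",  # assumed to come from AppBlockingResponse
        workflow_run_id="run-1",
        data=WorkflowAppPausedBlockingResponse.Data(
            id="run-1",
            workflow_id="wf-1",
            status=WorkflowExecutionStatus.PAUSED,  # assumed enum member
            elapsed_time=1.25,
            total_tokens=0,
            total_steps=3,
            created_at=1735689600,
            finished_at=None,
            paused_nodes=["human-input-node"],
            reasons=[{"TYPE": "human_input_required", "form_id": "form-123"}],
        ),
    )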
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from copy import deepcopy
 from typing import Any

 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider


 class DifyCredentialsProvider:
+    """Resolves and returns LLM credentials for a given provider and model.
+
+    Fetched credentials are stored in :attr:`credentials_cache` and reused for
+    subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
+    Because of that cache, a single instance can return stale credentials after
+    the tenant or provider configuration changes (e.g. API key rotation).
+
+    Do **not** keep one instance for the lifetime of a process or across
+    unrelated invocations. Create a new provider per request, workflow run, or
+    other bounded scope where up-to-date credentials matter.
+    """
+
     tenant_id: str
     provider_manager: ProviderManager
+    credentials_cache: dict[tuple[str, str], dict[str, Any]]

     def __init__(
         self,
@@ -30,8 +44,12 @@ class DifyCredentialsProvider:
             user_id=run_context.user_id,
         )
         self.provider_manager = provider_manager
+        self.credentials_cache = {}

     def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
+        if (provider_name, model_name) in self.credentials_cache:
+            return deepcopy(self.credentials_cache[(provider_name, model_name)])
+
         provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
         provider_configuration = provider_configurations.get(provider_name)
         if not provider_configuration:
@@ -46,6 +64,7 @@ class DifyCredentialsProvider:
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

+        self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
         return credentials


@@ -65,7 +84,8 @@ class DifyModelFactory:
             provider_manager=create_plugin_provider_manager(
                 tenant_id=run_context.tenant_id,
                 user_id=run_context.user_id,
-            )
+            ),
+            enable_credentials_cache=True,
         )
         self.model_manager = model_manager

@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
         tenant_id=run_context.tenant_id,
         user_id=run_context.user_id,
     )
-    model_manager = ModelManager(provider_manager=provider_manager)
+    model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)

     return (
         DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
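A sketch of the intended lifecycle, assuming a prepared `DifyRunContext`; the provider and model names are placeholders, and the second tuple element's type is assumed:

    credentials_provider, model_factory = build_dify_model_access(run_context)
    first = credentials_provider.fetch("openai", "gpt-4o")   # resolved, then cached
    second = credentials_provider.fetch("openai", "gpt-4o")  # served from the cache
    assert first == second and first is not second           # cache hits return a deepcopy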
41
api/core/helper/creators.py
Normal file
@@ -0,0 +1,41 @@
+"""
+Helper module for Creators Platform integration.
+
+Provides functionality to upload DSL files to the Creators Platform
+and generate redirect URLs with OAuth authorization codes.
+"""
+
+import logging
+from urllib.parse import urlencode
+
+import httpx
+from yarl import URL
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
+
+
+def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
+    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
+    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
+    response.raise_for_status()
+    data = response.json()
+    claim_code = data.get("data", {}).get("claim_code")
+    if not claim_code:
+        raise ValueError("Creators Platform did not return a valid claim_code")
+    return claim_code
+
+
+def get_redirect_url(user_account_id: str, claim_code: str) -> str:
+    base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
+    params: dict[str, str] = {"dsl_claim_code": claim_code}
+    client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
+    if client_id:
+        from services.oauth_server import OAuthServerService

+        oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
+        params["oauth_code"] = oauth_code
+    return f"{base_url}?{urlencode(params)}"
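End-to-end sketch with made-up inputs; it requires a reachable Creators Platform and a configured `CREATORS_PLATFORM_API_URL`:

    dsl_bytes = b"app:\n  name: demo\n"                      # illustrative DSL payload
    claim_code = upload_dsl(dsl_bytes, filename="demo.yaml")
    url = get_redirect_url(user_account_id="acct-1", claim_code=claim_code)
    # -> "<CREATORS_PLATFORM_API_URL>?dsl_claim_code=...&oauth_code=..." when an OAuth
    #    client id is configured; otherwise only dsl_claim_code is appended.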
@@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.llm_generator.prompts import (
     CONVERSATION_TITLE_PROMPT,
-    DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
-    DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     LLM_MODIFY_CODE_SYSTEM,
@@ -217,8 +215,8 @@ class LLMGenerator:
         else:
             # Default-model generation keeps the built-in suggested-questions tuning.
             model_parameters = {
-                "max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
-                "temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
+                "max_tokens": 2560,
+                "temperature": 0.0,
             }
             stop = []

@@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)

 class SuggestedQuestionsAfterAnswerOutputParser:
     def __init__(self, instruction_prompt: str | None = None) -> None:
-        self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
+        self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
+
+    @staticmethod
+    def _build_instruction_prompt(instruction_prompt: str | None) -> str:
+        if not instruction_prompt or not instruction_prompt.strip():
+            return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
+
+        return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'

     def get_format_instructions(self) -> str:
         return self._instruction_prompt
@@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     '["question1","question2","question3"]\n'
 )

-DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
-DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
-
 GENERATOR_QA_PROMPT = (
     "<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
     " in the long text. Please think step by step."
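The fallback behavior in `_build_instruction_prompt` can be pinned down with a quick check; both expectations follow directly from the code above:

    parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt="   ")
    assert parser.get_format_instructions() == DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

    parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt="Ask about pricing")
    assert parser.get_format_instructions() == (
        'Ask about pricing\nYou must output a JSON array like ["question1", "question2", "question3"].'
    )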
@@ -1,5 +1,6 @@
 import logging
 from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
+from copy import deepcopy
 from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

 from configs import dify_config
@@ -36,11 +37,13 @@ class ModelInstance:
     Model instance class.
     """

-    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
+    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
         self.provider_model_bundle = provider_model_bundle
         self.model_name = model
         self.provider = provider_model_bundle.configuration.provider.provider
-        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        if credentials is None:
+            credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        self.credentials = credentials
         # Runtime LLM invocation fields.
         self.parameters: Mapping[str, Any] = {}
         self.stop: Sequence[str] = ()
@@ -434,8 +437,30 @@ class ModelInstance:


 class ModelManager:
-    def __init__(self, provider_manager: ProviderManager):
+    """Resolves :class:`ModelInstance` objects for a tenant and provider.
+
+    When ``enable_credentials_cache`` is ``True``, resolved credentials for each
+    ``(tenant_id, provider, model_type, model)`` are stored in
+    ``_credentials_cache`` and reused. That can return **stale** credentials after
+    API keys or provider settings change, so a manager constructed with
+    ``enable_credentials_cache=True`` should not be kept for the lifetime of a
+    process or shared across unrelated work. Prefer a new manager per request,
+    workflow run, or similar bounded scope.
+
+    The default is ``enable_credentials_cache=False``; in that mode the internal
+    credential cache is not populated, and each ``get_model_instance`` call
+    loads credentials from the current provider configuration.
+    """
+
+    def __init__(
+        self,
+        provider_manager: ProviderManager,
+        *,
+        enable_credentials_cache: bool = False,
+    ) -> None:
         self._provider_manager = provider_manager
+        self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
+        self._enable_credentials_cache = enable_credentials_cache

     @classmethod
     def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@@ -463,8 +488,19 @@ class ModelManager:
             tenant_id=tenant_id, provider=provider, model_type=model_type
         )

-        model_instance = ModelInstance(provider_model_bundle, model)
-        return model_instance
+        cred_cache_key = (tenant_id, provider, model_type.value, model)
+
+        if cred_cache_key in self._credentials_cache:
+            return ModelInstance(
+                provider_model_bundle,
+                model,
+                deepcopy(self._credentials_cache[cred_cache_key]),
+            )
+
+        ret = ModelInstance(provider_model_bundle, model)
+        if self._enable_credentials_cache:
+            self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
+        return ret

     def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
         """
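Usage sketch for the cache flag; the tenant and model names are placeholders, and the exact `get_model_instance` signature is inferred from the hunk above rather than shown in full:

    manager = ModelManager(provider_manager, enable_credentials_cache=True)
    a = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4o")
    b = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4o")
    # b was built from deep-copied cached credentials; a later API-key rotation
    # would not be visible to this manager, so scope it to one request or run.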
@@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return dict(keyword_table_dict["__data__"]["table"])
+                data: Any = keyword_table_dict["__data__"]
+                return dict(data["table"])
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
@@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
         """Extract keywords with JIEBA tfidf."""
         keywords = self._tfidf.extract_tags(
             sentence=text,
-            topK=max_keywords_per_chunk,
+            topK=max_keywords_per_chunk or 10,
         )
         # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
         keywords = cast(list[str], keywords)
@@ -551,6 +551,7 @@ class RetrievalService:
         child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

         for i in child_index_nodes:
+            assert i.index_node_id
             segment_ids.append(i.segment_id)
             if i.segment_id in child_chunk_map:
                 child_chunk_map[i.segment_id].append(i)
@@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
 from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
+from models.enums import SegmentType


 class DatasetDocumentStore:
@@ -127,6 +128,7 @@ class DatasetDocumentStore:
         if save_child:
             if doc.children:
                 for position, child in enumerate(doc.children, start=1):
+                    assert self._document_id
                     child_segment = ChildChunk(
                         tenant_id=self._dataset.tenant_id,
                         dataset_id=self._dataset.id,
@@ -137,7 +139,7 @@ class DatasetDocumentStore:
                         index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
-                        type="automatic",
+                        type=SegmentType.AUTOMATIC,
                         created_by=self._user_id,
                     )
                     db.session.add(child_segment)
@@ -163,6 +165,7 @@ class DatasetDocumentStore:
                 )
                 # add new child chunks
                 for position, child in enumerate(doc.children, start=1):
+                    assert self._document_id
                     child_segment = ChildChunk(
                         tenant_id=self._dataset.tenant_id,
                         dataset_id=self._dataset.id,
@@ -173,7 +176,7 @@ class DatasetDocumentStore:
                         index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
-                        type="automatic",
+                        type=SegmentType.AUTOMATIC,
                         created_by=self._user_id,
                     )
                     db.session.add(child_segment)
@@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
         result: LLMResult = model_instance.invoke_llm(  # pyright: ignore[reportCallIssue, reportArgumentType]
             prompt_messages=prompt_messages,
             tools=dataset_tools,
-            stream=False,
+            stream=False,  # pyright: ignore[reportArgumentType]
             model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
         )
         usage = result.usage or LLMUsage.empty_usage()
@@ -14,23 +14,23 @@ from configs import dify_config
 logger = logging.getLogger(__name__)


-class OAuthEncryptionError(Exception):
-    """OAuth encryption/decryption specific error"""
+class EncryptionError(Exception):
+    """Encryption/decryption specific error"""

     pass


-class SystemOAuthEncrypter:
+class SystemEncrypter:
     """
-    A simple OAuth parameters encrypter using AES-CBC encryption.
+    A simple parameters encrypter using AES-CBC encryption.

-    This class provides methods to encrypt and decrypt OAuth parameters
+    This class provides methods to encrypt and decrypt parameters
     using AES-CBC mode with a key derived from the application's SECRET_KEY.
     """

     def __init__(self, secret_key: str | None = None):
         """
-        Initialize the OAuth encrypter.
+        Initialize the encrypter.

         Args:
             secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
         # Generate a fixed 256-bit key using SHA-256
         self.key = hashlib.sha256(secret_key.encode()).digest()

-    def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
+    def encrypt_params(self, params: Mapping[str, Any]) -> str:
         """
-        Encrypt OAuth parameters.
+        Encrypt parameters.

         Args:
-            oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
+            params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}

         Returns:
             Base64-encoded encrypted string

         Raises:
-            OAuthEncryptionError: If encryption fails
-            ValueError: If oauth_params is invalid
+            EncryptionError: If encryption fails
+            ValueError: If params is invalid
         """

         try:
@@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
             cipher = AES.new(self.key, AES.MODE_CBC, iv)

             # Encrypt data
-            padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
+            padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
             encrypted_data = cipher.encrypt(padded_data)

             # Combine IV and encrypted data
@@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
             return base64.b64encode(combined).decode()

         except Exception as e:
-            raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
+            raise EncryptionError(f"Encryption failed: {str(e)}") from e

-    def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
+    def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
         """
-        Decrypt OAuth parameters.
+        Decrypt parameters.

         Args:
             encrypted_data: Base64-encoded encrypted string

         Returns:
-            Decrypted OAuth parameters dictionary
+            Decrypted parameters dictionary

         Raises:
-            OAuthEncryptionError: If decryption fails
+            EncryptionError: If decryption fails
             ValueError: If encrypted_data is invalid
         """
         if not isinstance(encrypted_data, str):
@@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
             unpadded_data = unpad(decrypted_data, AES.block_size)

             # Parse JSON
-            oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
+            params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)

-            if not isinstance(oauth_params, dict):
+            if not isinstance(params, dict):
                 raise ValueError("Decrypted data is not a valid dictionary")

-            return oauth_params
+            return params

         except Exception as e:
-            raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
+            raise EncryptionError(f"Decryption failed: {str(e)}") from e


 # Factory function for creating encrypter instances
-def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
+def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
     """
-    Create an OAuth encrypter instance.
+    Create an encrypter instance.

     Args:
         secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY

     Returns:
-        SystemOAuthEncrypter instance
+        SystemEncrypter instance
     """
-    return SystemOAuthEncrypter(secret_key=secret_key)
+    return SystemEncrypter(secret_key=secret_key)


 # Global encrypter instance (for backward compatibility)
-_oauth_encrypter: SystemOAuthEncrypter | None = None
+_encrypter: SystemEncrypter | None = None


-def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
+def get_system_encrypter() -> SystemEncrypter:
     """
-    Get the global OAuth encrypter instance.
+    Get the global encrypter instance.

     Returns:
-        SystemOAuthEncrypter instance
+        SystemEncrypter instance
     """
-    global _oauth_encrypter
-    if _oauth_encrypter is None:
-        _oauth_encrypter = SystemOAuthEncrypter()
-    return _oauth_encrypter
+    global _encrypter
+    if _encrypter is None:
+        _encrypter = SystemEncrypter()
+    return _encrypter


 # Convenience functions for backward compatibility
-def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
+def encrypt_system_params(params: Mapping[str, Any]) -> str:
     """
-    Encrypt OAuth parameters using the global encrypter.
+    Encrypt parameters using the global encrypter.

     Args:
-        oauth_params: OAuth parameters dictionary
+        params: Parameters dictionary

     Returns:
         Base64-encoded encrypted string
     """
-    return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
+    return get_system_encrypter().encrypt_params(params)


-def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
+def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
     """
-    Decrypt OAuth parameters using the global encrypter.
+    Decrypt parameters using the global encrypter.

     Args:
         encrypted_data: Base64-encoded encrypted string

     Returns:
-        Decrypted OAuth parameters dictionary
+        Decrypted parameters dictionary
     """
-    return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
+    return get_system_encrypter().decrypt_params(encrypted_data)
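Round-trip sketch with the renamed API; the payload is illustrative, and the key derives from `dify_config.SECRET_KEY`:

    encrypter = create_system_encrypter()
    blob = encrypter.encrypt_params({"client_id": "abc", "client_secret": "xyz"})
    assert encrypter.decrypt_params(blob) == {"client_id": "abc", "client_secret": "xyz"}

    # The module-level helpers reuse one lazily created global instance.
    assert decrypt_system_params(encrypt_system_params({"k": "v"})) == {"k": "v"}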
@@ -12,20 +12,16 @@ from collections.abc import Sequence
 from sqlalchemy import select
 from sqlalchemy.orm import Session

+from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
 from extensions.ext_database import db
 from models.human_input import HumanInputFormRecipient, RecipientType

-_FORM_TOKEN_PRIORITY = {
-    RecipientType.BACKSTAGE: 0,
-    RecipientType.CONSOLE: 1,
-    RecipientType.STANDALONE_WEB_APP: 2,
-}
-

 def load_form_tokens_by_form_id(
     form_ids: Sequence[str],
     *,
     session: Session | None = None,
+    surface: HumanInputSurface | None = None,
 ) -> dict[str, str]:
     """Load the preferred access token for each human input form."""
     unique_form_ids = list(dict.fromkeys(form_ids))
@@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
         return {}

     if session is not None:
-        return _load_form_tokens_by_form_id(session, unique_form_ids)
+        return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)

     with Session(bind=db.engine, expire_on_commit=False) as new_session:
-        return _load_form_tokens_by_form_id(new_session, unique_form_ids)
+        return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)


-def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
-    tokens_by_form_id: dict[str, tuple[int, str]] = {}
+def _load_form_tokens_by_form_id(
+    session: Session,
+    form_ids: Sequence[str],
+    *,
+    surface: HumanInputSurface | None = None,
+) -> dict[str, str]:
+    recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
     stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
     for recipient in session.scalars(stmt):
-        priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
-        if priority is None or not recipient.access_token:
+        if not recipient.access_token:
             continue
+        recipients_by_form_id.setdefault(recipient.form_id, []).append(
+            (recipient.recipient_type, recipient.access_token)
+        )

-        candidate = (priority, recipient.access_token)
-        current = tokens_by_form_id.get(recipient.form_id)
-        if current is None or candidate[0] < current[0]:
-            tokens_by_form_id[recipient.form_id] = candidate
+    tokens_by_form_id: dict[str, str] = {}
+    for form_id, recipients in recipients_by_form_id.items():
+        token = _get_surface_form_token(recipients, surface=surface)
+        if token is not None:
+            tokens_by_form_id[form_id] = token
+    return tokens_by_form_id

-    return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
+
+def _get_surface_form_token(
+    recipients: Sequence[tuple[RecipientType, str]],
+    *,
+    surface: HumanInputSurface | None,
+) -> str | None:
+    if surface == HumanInputSurface.SERVICE_API:
+        for recipient_type, token in recipients:
+            if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
+                return token
+
+    return get_preferred_form_token(recipients)
73
api/core/workflow/human_input_policy.py
Normal file
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from collections.abc import Mapping, Sequence
+from enum import StrEnum
+from typing import Any
+
+from graphon.entities.pause_reason import PauseReasonType
+from models.human_input import RecipientType
+
+
+class HumanInputSurface(StrEnum):
+    SERVICE_API = "service_api"
+    CONSOLE = "console"
+
+
+# Service API is intentionally narrower than other surfaces: app-token callers
+# should only be able to act on end-user web forms, not internal console flows.
+_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
+    HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
+    HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
+}
+
+# A single HITL form can have multiple recipient records; this shared priority
+# keeps every API surface consistent about which resume token to expose.
+_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
+    RecipientType.BACKSTAGE: 0,
+    RecipientType.CONSOLE: 1,
+    RecipientType.STANDALONE_WEB_APP: 2,
+}
+
+
+def is_recipient_type_allowed_for_surface(
+    recipient_type: RecipientType | None,
+    surface: HumanInputSurface,
+) -> bool:
+    if recipient_type is None:
+        return False
+    return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
+
+
+def get_preferred_form_token(
+    recipients: Sequence[tuple[RecipientType, str]],
+) -> str | None:
+    chosen_token: str | None = None
+    chosen_priority: int | None = None
+    for recipient_type, token in recipients:
+        priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
+        if priority is None or not token:
+            continue
+        if chosen_priority is None or priority < chosen_priority:
+            chosen_priority = priority
+            chosen_token = token
+    return chosen_token
+
+
+def enrich_human_input_pause_reasons(
+    reasons: Sequence[Mapping[str, Any]],
+    *,
+    form_tokens_by_form_id: Mapping[str, str],
+    expiration_times_by_form_id: Mapping[str, int],
+) -> list[dict[str, Any]]:
+    enriched: list[dict[str, Any]] = []
+    for reason in reasons:
+        updated = dict(reason)
+        if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
+            form_id = updated.get("form_id")
+            if isinstance(form_id, str):
+                updated["form_token"] = form_tokens_by_form_id.get(form_id)
+                expiration_time = expiration_times_by_form_id.get(form_id)
+                if expiration_time is not None:
+                    updated["expiration_time"] = expiration_time
+        enriched.append(updated)
+    return enriched
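The two policies compose as follows; the tokens are placeholders, and both expectations follow from the tables above:

    recipients = [
        (RecipientType.STANDALONE_WEB_APP, "web-token"),
        (RecipientType.BACKSTAGE, "backstage-token"),
    ]
    # Priority order: the backstage token wins by default...
    assert get_preferred_form_token(recipients) == "backstage-token"
    # ...while surface checks keep console and end-user recipients apart.
    assert is_recipient_type_allowed_for_surface(RecipientType.BACKSTAGE, HumanInputSurface.CONSOLE)
    assert not is_recipient_type_allowed_for_surface(
        RecipientType.STANDALONE_WEB_APP, HumanInputSurface.CONSOLE
    )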
@@ -1,56 +1,17 @@
-import logging
-from dataclasses import dataclass
 from enum import StrEnum, auto

-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class QuotaCharge:
-    """
-    Result of a quota consumption operation.
-
-    Attributes:
-        success: Whether the quota charge succeeded
-        charge_id: UUID for refund, or None if failed/disabled
-    """
-
-    success: bool
-    charge_id: str | None
-    _quota_type: "QuotaType"
-
-    def refund(self) -> None:
-        """
-        Refund this quota charge.
-
-        Safe to call even if charge failed or was disabled.
-        This method guarantees no exceptions will be raised.
-        """
-        if self.charge_id:
-            self._quota_type.refund(self.charge_id)
-            logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
-

 class QuotaType(StrEnum):
     """
     Supported quota types for tenant feature usage.
-
-    Add additional types here whenever new billable features become available.
     """

-    # Trigger execution quota
     TRIGGER = auto()
-
-    # Workflow execution quota
     WORKFLOW = auto()
-
     UNLIMITED = auto()

     @property
     def billing_key(self) -> str:
-        """
-        Get the billing key for the feature.
-        """
         match self:
             case QuotaType.TRIGGER:
                 return "trigger_event"
@@ -58,152 +19,3 @@ class QuotaType(StrEnum):
                 return "api_rate_limit"
             case _:
                 raise ValueError(f"Invalid quota type: {self}")
-
-    def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
-        """
-        Consume quota for the feature.
-
-        Args:
-            tenant_id: The tenant identifier
-            amount: Amount to consume (default: 1)
-
-        Returns:
-            QuotaCharge with success status and charge_id for refund
-
-        Raises:
-            QuotaExceededError: When quota is insufficient
-        """
-        from configs import dify_config
-        from services.billing_service import BillingService
-        from services.errors.app import QuotaExceededError
-
-        if not dify_config.BILLING_ENABLED:
-            logger.debug("Billing disabled, allowing request for %s", tenant_id)
-            return QuotaCharge(success=True, charge_id=None, _quota_type=self)
-
-        logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
-
-        if amount <= 0:
-            raise ValueError("Amount to consume must be greater than 0")
-
-        try:
-            response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
-
-            if response.get("result") != "success":
-                logger.warning(
-                    "Failed to consume quota for %s, feature %s details: %s",
-                    tenant_id,
-                    self.value,
-                    response.get("detail"),
-                )
-                raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
-
-            charge_id = response.get("history_id")
-            logger.debug(
-                "Successfully consumed %d %s quota for tenant %s, charge_id: %s",
-                amount,
-                self.value,
-                tenant_id,
-                charge_id,
-            )
-            return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
-
-        except QuotaExceededError:
-            raise
-        except Exception:
-            # fail-safe: allow request on billing errors
-            logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
-            return unlimited()
-
-    def check(self, tenant_id: str, amount: int = 1) -> bool:
-        """
-        Check if tenant has sufficient quota without consuming.
-
-        Args:
-            tenant_id: The tenant identifier
-            amount: Amount to check (default: 1)
-
-        Returns:
-            True if quota is sufficient, False otherwise
-        """
-        from configs import dify_config
-
-        if not dify_config.BILLING_ENABLED:
-            return True
-
-        if amount <= 0:
-            raise ValueError("Amount to check must be greater than 0")
-
-        try:
-            remaining = self.get_remaining(tenant_id)
-            return remaining >= amount if remaining != -1 else True
-        except Exception:
-            logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
-            # fail-safe: allow request on billing errors
-            return True
-
-    def refund(self, charge_id: str) -> None:
-        """
-        Refund quota using charge_id from consume().
-
-        This method guarantees no exceptions will be raised.
-        All errors are logged but silently handled.
-
-        Args:
-            charge_id: The UUID returned from consume()
-        """
-        try:
-            from configs import dify_config
-            from services.billing_service import BillingService
-
-            if not dify_config.BILLING_ENABLED:
-                return
-
-            if not charge_id:
-                logger.warning("Cannot refund: charge_id is empty")
-                return
-
-            logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
-
-            response = BillingService.refund_tenant_feature_plan_usage(charge_id)
-            if response.get("result") == "success":
-                logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
-            else:
-                logger.warning("Refund failed for charge_id: %s", charge_id)
-
-        except Exception:
-            # Catch ALL exceptions - refund must never fail
-            logger.exception("Failed to refund quota for charge_id: %s", charge_id)
-            # Don't raise - refund is best-effort and must be silent
-
-    def get_remaining(self, tenant_id: str) -> int:
-        """
-        Get remaining quota for the tenant.
-
-        Args:
-            tenant_id: The tenant identifier
-
-        Returns:
-            Remaining quota amount
-        """
-        from services.billing_service import BillingService
-
-        try:
-            usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
-            # Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
-            if isinstance(usage_info, dict):
-                return usage_info.get("remaining", 0)
-            # If it returns a simple number, treat it as remaining
-            return int(usage_info) if usage_info else 0
-        except Exception:
-            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
-            return -1
-
-
-def unlimited() -> QuotaCharge:
-    """
-    Return a quota charge for unlimited quota.
-
-    This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
-    """
-    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
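After the removal, the module reduces to the key mapping; only the TRIGGER arm is fully visible in this diff:

    assert QuotaType.TRIGGER.billing_key == "trigger_event"
    # Other members resolve through the remaining match arms; an unmapped
    # member falls through to `case _` and raises ValueError.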
@@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
         return attachment_list


-class ChildChunk(Base):
+class ChildChunk(TypeBase):
     __tablename__ = "child_chunks"
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@@ -1046,29 +1046,42 @@ class ChildChunk(Base):
     )

     # initial fields
-    id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
-    tenant_id = mapped_column(StringUUID, nullable=False)
-    dataset_id = mapped_column(StringUUID, nullable=False)
-    document_id = mapped_column(StringUUID, nullable=False)
-    segment_id = mapped_column(StringUUID, nullable=False)
+    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
-    content = mapped_column(LongText, nullable=False)
+    content: Mapped[str] = mapped_column(LongText, nullable=False)
     word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
     # indexing fields
-    index_node_id = mapped_column(String(255), nullable=True)
-    index_node_hash = mapped_column(String(255), nullable=True)
-    type: Mapped[SegmentType] = mapped_column(
-        EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
+    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
     )
-    created_by = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
-    updated_by = mapped_column(StringUUID, nullable=True)
+    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
     updated_at: Mapped[datetime] = mapped_column(
-        DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
+        DateTime,
+        nullable=False,
+        server_default=sa.func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+        init=False,
     )
-    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
-    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
-    error = mapped_column(LongText, nullable=True)
+    indexing_at: Mapped[datetime | None] = mapped_column(
+        DateTime, nullable=True, insert_default=None, server_default=None, init=False
+    )
+    completed_at: Mapped[datetime | None] = mapped_column(
+        DateTime, nullable=True, insert_default=None, server_default=None, init=False
+    )
+    index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+    index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+    type: Mapped[SegmentType] = mapped_column(
+        EnumText(SegmentType, length=255),
+        nullable=False,
+        server_default=sa.text("'automatic'"),
+        default=SegmentType.AUTOMATIC,
+    )
+    error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)

     @property
     def dataset(self):
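With the dataclass-style mapping, `init=False` columns (id, timestamps, error) are supplied by defaults or the database. A construction sketch with placeholder ids, assuming `TypeBase` is a `MappedAsDataclass` declarative base:

    chunk = ChildChunk(
        tenant_id="tenant-1",        # placeholder UUIDs, not real values
        dataset_id="dataset-1",
        document_id="document-1",
        segment_id="segment-1",
        position=1,
        content="child chunk text",
        word_count=16,
        created_by="user-1",
        index_node_id=None,          # defaulted; shown for clarity
        index_node_hash=None,
        type=SegmentType.AUTOMATIC,  # also the server-side default
    )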
@@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
     )

     id: Mapped[str] = mapped_column(
-        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+        StringUUID,
+        insert_default=lambda: str(uuid4()),
+        default_factory=lambda: str(uuid4()),
+        init=False,
     )
     app_id: Mapped[str] = mapped_column(StringUUID)
     question: Mapped[str] = mapped_column(LongText, nullable=False)
     content: Mapped[str] = mapped_column(LongText, nullable=False)
+    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
     account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
     message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
-    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
     created_at: Mapped[datetime] = mapped_column(
         sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
     )
@@ -225,8 +225,10 @@ class TestSpanBuilder:
         span = builder.build_span(span_data)
         assert isinstance(span, ReadableSpan)
         assert span.name == "test-span"
+        assert span.context is not None
         assert span.context.trace_id == 123
         assert span.context.span_id == 456
+        assert span.parent is not None
         assert span.parent.span_id == 789
         assert span.resource == resource
         assert span.attributes == {"attr1": "val1"}
@@ -64,12 +64,13 @@ class TestSpanData:

     def test_span_data_missing_required_fields(self):
         with pytest.raises(ValidationError):
-            SpanData(
-                trace_id=123,
-                # span_id missing
-                name="test_span",
-                start_time=1000,
-                end_time=2000,
+            SpanData.model_validate(
+                {
+                    "trace_id": 123,
+                    "name": "test_span",
+                    "start_time": 1000,
+                    "end_time": 2000,
+                }
             )

     def test_span_data_arbitrary_types_allowed(self):
|||||||
@ -2,12 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from typing import cast
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
||||||
import pytest
|
import pytest
|
||||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||||
from dify_trace_aliyun.config import AliyunConfig
|
from dify_trace_aliyun.config import AliyunConfig
|
||||||
|
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||||
from dify_trace_aliyun.entities.semconv import (
|
from dify_trace_aliyun.entities.semconv import (
|
||||||
GEN_AI_COMPLETION,
|
GEN_AI_COMPLETION,
|
||||||
GEN_AI_INPUT_MESSAGE,
|
GEN_AI_INPUT_MESSAGE,
|
||||||
@ -44,7 +46,7 @@ class RecordingTraceClient:
|
|||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.added_spans: list[object] = []
|
self.added_spans: list[object] = []
|
||||||
|
|
||||||
def add_span(self, span) -> None:
|
def add_span(self, span: object) -> None:
|
||||||
self.added_spans.append(span)
|
self.added_spans.append(span)
|
||||||
|
|
||||||
def api_check(self) -> bool:
|
def api_check(self) -> bool:
|
||||||
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
|
|||||||
trace_id=trace_id,
|
trace_id=trace_id,
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
is_remote=False,
|
is_remote=False,
|
||||||
trace_flags=TraceFlags.SAMPLED,
|
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||||
)
|
)
|
||||||
return Link(context)
|
return Link(context)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trace_metadata(
|
||||||
|
trace_id: int = 1,
|
||||||
|
workflow_span_id: int = 2,
|
||||||
|
session_id: str = "s",
|
||||||
|
user_id: str = "u",
|
||||||
|
links: list[Link] | None = None,
|
||||||
|
) -> TraceMetadata:
|
||||||
|
return TraceMetadata(
|
||||||
|
trace_id=trace_id,
|
||||||
|
workflow_span_id=workflow_span_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
links=[] if links is None else links,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
|
||||||
|
return cast(RecordingTraceClient, trace_instance.trace_client)
|
||||||
|
|
||||||
|
|
||||||
|
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
|
||||||
|
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
|
||||||
|
|
||||||
|
|
||||||
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
|
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
|
||||||
defaults = {
|
defaults = {
|
||||||
"workflow_id": "workflow-id",
|
"workflow_id": "workflow-id",
|
||||||
@@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
     trace_instance.workflow_trace(trace_info)

     add_workflow_span.assert_called_once()
-    passed_trace_metadata = add_workflow_span.call_args.args[1]
+    passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
     assert passed_trace_metadata.trace_id == 111
     assert passed_trace_metadata.workflow_span_id == 222
     assert passed_trace_metadata.session_id == "c"
     assert passed_trace_metadata.user_id == "u"
     assert passed_trace_metadata.links == []

-    assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
+    assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]


 def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_message_trace_info(message_data=None)
     trace_instance.message_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
     )
     trace_instance.message_trace(trace_info)

-    assert len(trace_instance.trace_client.added_spans) == 2
-    message_span, llm_span = trace_instance.trace_client.added_spans
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 2
+    message_span, llm_span = spans

     assert message_span.name == "message"
     assert message_span.trace_id == 10
@@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
 def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_dataset_retrieval_trace_info(message_data=None)
     trace_instance.dataset_retrieval_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
     monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])

     trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "dataset_retrieval"
     assert span.attributes[RETRIEVAL_QUERY] == "query"
     assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
 def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
     trace_info = _make_tool_trace_info(message_data=None)
     trace_instance.tool_trace(trace_info)
-    assert trace_instance.trace_client.added_spans == []
+    assert _recording_trace_client(trace_instance).added_spans == []


 def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
         )
     )

-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "my-tool"
     assert span.status == status
     assert span.attributes[TOOL_NAME] == "my-tool"
@@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
 def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))

@@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
 ):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))

@@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
 def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))

@@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
 def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))

@@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
 ):
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     trace_info = _make_workflow_trace_info()
-    trace_metadata = MagicMock()
+    trace_metadata = _make_trace_metadata()

     monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
     node_execution.node_type = BuiltinNodeTypes.CODE
@@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "title"
@@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
+    trace_metadata = _make_trace_metadata(links=[_make_link()])
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "my-tool"
@@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
         aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
     )

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "retrieval"
@@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
     monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
     monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()
     node_execution = MagicMock(spec=WorkflowNodeExecution)
     node_execution.id = "node-id"
     node_execution.title = "llm"
@@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     status = Status(StatusCode.OK)
     monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)

-    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
+    trace_metadata = _make_trace_metadata()

     # CASE 1: With message_id
     trace_info = _make_workflow_trace_info(
@@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     )
     trace_instance.add_workflow_span(trace_info, trace_metadata)

-    assert len(trace_instance.trace_client.added_spans) == 2
-    message_span = trace_instance.trace_client.added_spans[0]
-    workflow_span = trace_instance.trace_client.added_spans[1]
+    client = _recording_trace_client(trace_instance)
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 2
+    message_span = spans[0]
+    workflow_span = spans[1]

     assert message_span.name == "message"
     assert message_span.span_kind == SpanKind.SERVER
@@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
     assert workflow_span.span_kind == SpanKind.INTERNAL
     assert workflow_span.parent_span_id == 20

-    trace_instance.trace_client.added_spans.clear()
+    client.added_spans.clear()

     # CASE 2: Without message_id
     trace_info_no_msg = _make_workflow_trace_info(message_id=None)
     trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "workflow"
     assert span.span_kind == SpanKind.SERVER
     assert span.parent_span_id is None
@@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
     trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
     trace_instance.suggested_question_trace(trace_info)

-    assert len(trace_instance.trace_client.added_spans) == 1
-    span = trace_instance.trace_client.added_spans[0]
+    spans = _recorded_span_data(trace_instance)
+    assert len(spans) == 1
+    span = spans[0]
     assert span.name == "suggested_question"
     assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

@@ -1,4 +1,6 @@
 import json
+from collections.abc import Mapping
+from typing import Any, cast
 from unittest.mock import MagicMock

 from dify_trace_aliyun.entities.semconv import (
@@ -170,7 +172,7 @@ def test_create_common_span_attributes():

 def test_format_retrieval_documents():
     # Not a list
-    assert format_retrieval_documents("not a list") == []
+    assert format_retrieval_documents(cast(list[object], "not a list")) == []

     # Valid list
     docs = [
@@ -211,7 +213,7 @@ def test_format_retrieval_documents():

 def test_format_input_messages():
     # Not a dict
-    assert format_input_messages(None) == serialize_json_data([])
+    assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])

     # No prompts
     assert format_input_messages({}) == serialize_json_data([])
@@ -244,7 +246,7 @@ def test_format_input_messages():

 def test_format_output_messages():
     # Not a dict
-    assert format_output_messages(None) == serialize_json_data([])
+    assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])

     # No text
     assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

@@ -25,13 +25,13 @@ class TestAliyunConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            AliyunConfig()
+            AliyunConfig.model_validate({})

         with pytest.raises(ValidationError):
-            AliyunConfig(license_key="test_license")
+            AliyunConfig.model_validate({"license_key": "test_license"})

         with pytest.raises(ValidationError):
-            AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+            AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})

     def test_app_name_validation_empty(self):
         """Test app_name validation with empty value"""

@@ -1,4 +1,5 @@
 from datetime import UTC, datetime, timedelta
+from typing import cast
 from unittest.mock import MagicMock, patch

 import pytest
@@ -129,7 +130,7 @@ def test_set_span_status():
             return "SilentErrorRepr"

     span.reset_mock()
-    set_span_status(span, SilentError())
+    set_span_status(span, cast(Exception | str | None, SilentError()))
     assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

@@ -28,13 +28,13 @@ class TestLangfuseConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            LangfuseConfig()
+            LangfuseConfig.model_validate({})

         with pytest.raises(ValidationError):
-            LangfuseConfig(public_key="public")
+            LangfuseConfig.model_validate({"public_key": "public"})

         with pytest.raises(ValidationError):
-            LangfuseConfig(secret_key="secret")
+            LangfuseConfig.model_validate({"secret_key": "secret"})

     def test_host_validation_empty(self):
         """Test host validation with empty value"""

@@ -2,6 +2,7 @@

 from datetime import datetime, timedelta
 from types import SimpleNamespace
+from typing import cast
 from unittest.mock import MagicMock, patch

 from dify_trace_langfuse.config import LangfuseConfig
@@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:

         assert trace._get_completion_start_time(start_time, None) is None
         assert trace._get_completion_start_time(start_time, -1) is None
-        assert trace._get_completion_start_time(start_time, "invalid") is None
+        assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

@@ -21,13 +21,13 @@ class TestLangSmithConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            LangSmithConfig()
+            LangSmithConfig.model_validate({})

         with pytest.raises(ValidationError):
-            LangSmithConfig(api_key="key")
+            LangSmithConfig.model_validate({"api_key": "key"})

         with pytest.raises(ValidationError):
-            LangSmithConfig(project="project")
+            LangSmithConfig.model_validate({"project": "project"})

     def test_endpoint_validation_https_only(self):
         """Test endpoint validation only allows HTTPS"""

@@ -599,7 +599,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_instance.message_trace(_make_message_trace_info())
         mock_tracing["start"].assert_called_once()
@@ -609,7 +608,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(error="something broke")
         trace_instance.message_trace(trace_info)
@@ -620,7 +618,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None
         monkeypatch.setenv("FILES_URL", "http://files.test")

         file_data = SimpleNamespace(url="path/to/file.png")
@@ -638,7 +635,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
         trace_instance.message_trace(trace_info)
@@ -651,7 +647,6 @@ class TestMessageTrace:

         end_user = MagicMock()
         end_user.session_id = "session-xyz"
-        mock_db.session.query.return_value.where.return_value.first.return_value = end_user

         trace_info = _make_message_trace_info(
             metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@@ -664,7 +659,6 @@ class TestMessageTrace:
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None

         trace_info = _make_message_trace_info(
             metadata={"from_account_id": "acc-1"},

@@ -12,6 +12,7 @@ from __future__ import annotations

 import uuid
 from datetime import datetime
+from typing import cast
 from unittest.mock import MagicMock, patch

 from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
     return instance


+def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
+    return cast(MagicMock, instance.add_trace)
+
+
+def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
+    return cast(MagicMock, instance.add_span)
+
+
 # ---------------------------------------------------------------------------
 # _seed_to_uuid4
 # ---------------------------------------------------------------------------
@@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
     def test_root_span_is_created(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
-        assert instance.add_span.called
+        assert _add_span_mock(instance).called

     def test_root_span_id_matches_expected(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)

         expected = self._expected_root_span_id(trace_info)
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["id"] == expected

     def test_root_span_has_no_parent(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)

-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["parent_span_id"] is None

     def test_trace_name_is_workflow_trace(self):
@@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)

-        trace_kwargs = instance.add_trace.call_args_list[0][0][0]
+        trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
         assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE

     def test_root_span_name_is_workflow_trace(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)

-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE

     def test_root_span_has_workflow_tag(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)

-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert "workflow" in root_span_kwargs["tags"]

     def test_node_execution_spans_are_parented_to_root(self):
@@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
         instance = self._run(trace_info, node_executions=[node_exec])

         # call_args_list[0] = root span, [1] = node execution span
-        assert instance.add_span.call_count == 2
-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        add_span = _add_span_mock(instance)
+        assert add_span.call_count == 2
+        node_span_kwargs = add_span.call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] == expected_root_span_id

     def test_node_span_not_parented_to_workflow_app_log_id(self):
@@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
         instance = self._run(trace_info, node_executions=[node_exec])

         old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] != old_parent_id

     def test_root_span_id_differs_from_trace_id(self):
@@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
         trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
         instance = self._run(trace_info)

-        trace_kwargs = instance.add_trace.call_args_list[0][0][0]
+        trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
         assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE

     def test_root_span_uses_workflow_run_id_directly(self):
@@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
         instance = self._run(trace_info)

         expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["id"] == expected_root_span_id

     def test_root_span_id_differs_from_no_message_id_case(self):
@@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:

         instance = self._run(trace_info, node_executions=[node_exec])

-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] == expected_root_span_id

@@ -5,6 +5,7 @@ from __future__ import annotations
 import sys
 import types
 from types import SimpleNamespace
+from typing import Any, TypedDict, cast
 from unittest.mock import MagicMock

 import pytest
@@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
 from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
 from dify_trace_tencent.entities.tencent_trace_entity import SpanData
 from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags

 metric_reader_instances: list[DummyMetricReader] = []
 meter_provider_instances: list[DummyMeterProvider] = []
@@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
         self.kwargs = kwargs


+class PatchedCoreComponents(TypedDict):
+    span_exporter: MagicMock
+    span_processor: MagicMock
+    tracer: MagicMock
+    span: MagicMock
+    tracer_provider: MagicMock
+    logger: MagicMock
+    trace_api: Any
+
+
 def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
     """Drop fake metric modules into sys.modules so the client imports resolve."""

@@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:


 @pytest.fixture(autouse=True)
-def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
+def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
     span_exporter = MagicMock(name="span_exporter")
     monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))

@@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
     }


+def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
+    return SpanContext(
+        trace_id=trace_id,
+        span_id=span_id,
+        is_remote=False,
+        trace_flags=TraceFlags(TraceFlags.SAMPLED),
+    )
+
+
 def _build_client() -> TencentTraceClient:
     return TencentTraceClient(
         service_name="service",
@@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st


 def test_resolve_grpc_target_handles_errors() -> None:
-    assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
+    assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)


 @pytest.mark.parametrize(
@@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
     client.record_trace_duration(0.5)


-def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
+def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     client.hist_llm_duration = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
     logger.debug.assert_called()


-def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    ctx = _make_span_context(span_id=2)
+    span.get_span_context.return_value = ctx

     data = SpanData(
         trace_id=1,
@@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
     span.add_event.assert_called_once()
     span.set_status.assert_called_once()
     span.end.assert_called_once_with(end_time=20)
-    assert client.span_contexts[2] == "ctx"
+    assert client.span_contexts[2] == ctx


-def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
-    client.span_contexts[10] = "existing"
+    existing_context = _make_span_context(span_id=10)
+    client.span_contexts[10] = existing_context
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "child"
+    span.get_span_context.return_value = _make_span_context(span_id=11)

     data = SpanData(
         trace_id=1,
@@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[

     client._create_and_export_span(data)
     trace_api = patch_core_components["trace_api"]
-    trace_api.NonRecordingSpan.assert_called_once_with("existing")
+    trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
     trace_api.set_span_in_context.assert_called_once()


-def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)
     client.tracer.start_span.side_effect = RuntimeError("boom")

     client._create_and_export_span(
@@ -385,7 +407,7 @@ def test_get_project_url() -> None:
     assert client.get_project_url() == "https://console.cloud.tencent.com/apm"


-def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span_processor = patch_core_components["span_processor"]
     tracer_provider = patch_core_components["tracer_provider"]
@@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
     metric_reader.shutdown.assert_called_once()


-def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     meter_provider = meter_provider_instances[-1]
     meter_provider.shutdown.side_effect = RuntimeError("boom")
+    assert client.metric_reader is not None
     client.metric_reader.shutdown.side_effect = RuntimeError("boom")

     client.shutdown()
@@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
     assert client.metric_reader is None


-def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
+def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))

@@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
     logger.exception.assert_called_once()


-def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)

     data = SpanData.model_construct(
         trace_id=1,
@@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration = hist_mock

-    client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
+    client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["foo"], str)
     assert attrs["bar"] == 2
@@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_trace_duration")
     client.hist_trace_duration = hist_mock

-    client.record_trace_duration(1.0, {"meta": object(), "ok": True})
+    client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["meta"], str)
     assert attrs["ok"] is True
@@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     ],
 )
 def test_record_methods_handle_exceptions(
-    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
+    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
 ) -> None:
     client = _build_client()
     hist_mock = MagicMock(name=attr_name)
@@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
 def test_metrics_initializes_grpc_metric_exporter() -> None:
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)

-    assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
+    assert isinstance(exporter, DummyGrpcMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
-    assert metric_reader.exporter.kwargs["insecure"] is False
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
+    assert exporter.kwargs["insecure"] is False
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"


 def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)

-    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
+    assert isinstance(exporter, DummyHttpMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"


 def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)

-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
+    assert isinstance(exporter, DummyJsonMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
-    assert "preferred_temporality" in metric_reader.exporter.kwargs
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
+    assert "preferred_temporality" in exporter.kwargs


 def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
     monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
     _ = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)

-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
-    assert "preferred_temporality" not in metric_reader.exporter.kwargs
+    assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
+    assert "preferred_temporality" not in exporter.kwargs


 def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

@@ -31,13 +31,13 @@ class TestWeaveConfig:
     def test_missing_required_fields(self):
         """Test that required fields are enforced"""
         with pytest.raises(ValidationError):
-            WeaveConfig()
+            WeaveConfig.model_validate({})

         with pytest.raises(ValidationError):
-            WeaveConfig(api_key="key")
+            WeaveConfig.model_validate({"api_key": "key"})

         with pytest.raises(ValidationError):
-            WeaveConfig(project="project")
+            WeaveConfig.model_validate({"project": "project"})

     def test_endpoint_validation_https_only(self):
         """Test endpoint validation only allows HTTPS"""

@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):

         auth = PasswordAuthenticator(config.user, config.password)
         options = ClusterOptions(auth)
-        self._cluster = Cluster(config.connection_string, options)
+        self._cluster = Cluster(config.connection_string, options)  # pyright: ignore[reportArgumentType]
         self._bucket = self._cluster.bucket(config.bucket_name)
         self._scope = self._bucket.scope(config.scope_name)
         self._bucket_name = config.bucket_name
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # pyright: ignore[reportCallIssue]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast

 from packaging import version
 from pydantic import BaseModel, model_validator
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
     def _load_collection_fields(self, fields: list[str] | None = None):
         if fields is None:
             # Load collection fields from remote server
-            collection_info = self._client.describe_collection(self._collection_name)
+            collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
             fields = [field["name"] for field in collection_info["fields"]]
         # Since primary field is auto-id, no need to track it
         self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
             return False

         try:
-            milvus_version = self._client.get_server_version()
+            milvus_version_raw = self._client.get_server_version()
+            milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
             # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
             if "Zilliz Cloud" in milvus_version:
                 return True
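Both Milvus edits are typing hardenings: `describe_collection` is loosely typed upstream, so its result is cast before subscripting, and the server version is coerced to `str` before the substring check. The normalization step in isolation, assuming only that the client may return a non-`str` value:

    from typing import Any

    def normalize_version(raw: Any) -> str:
        # coerce defensively; the pymilvus stubs do not pin the return type
        return raw if isinstance(raw, str) else str(raw)

    assert "Zilliz Cloud" in normalize_version("Zilliz Cloud v2.5")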
@@ -3,7 +3,7 @@ import json
 import logging
 import re
 import uuid
-from typing import Any
+from typing import Any, TypedDict

 import jieba.posseg as pseg  # type: ignore
 import numpy
@@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
 oracledb.defaults.fetch_lobs = False


+class _OraclePoolParams(TypedDict, total=False):
+    user: str
+    password: str
+    dsn: str
+    min: int
+    max: int
+    increment: int
+    config_dir: str | None
+    wallet_location: str | None
+    wallet_password: str | None
+
+
 class OracleVectorConfig(BaseModel):
     user: str
     password: str
@@ -127,22 +139,18 @@ class OracleVector(BaseVector):
         return connection

     def _create_connection_pool(self, config: OracleVectorConfig):
-        pool_params = {
-            "user": config.user,
-            "password": config.password,
-            "dsn": config.dsn,
-            "min": 1,
-            "max": 5,
-            "increment": 1,
-        }
+        pool_params = _OraclePoolParams(
+            user=config.user,
+            password=config.password,
+            dsn=config.dsn,
+            min=1,
+            max=5,
+            increment=1,
+        )
         if config.is_autonomous:
-            pool_params.update(
-                {
-                    "config_dir": config.config_dir,
-                    "wallet_location": config.wallet_location,
-                    "wallet_password": config.wallet_password,
-                }
-            )
+            pool_params["config_dir"] = config.config_dir
+            pool_params["wallet_location"] = config.wallet_location
+            pool_params["wallet_password"] = config.wallet_password
         return oracledb.create_pool(**pool_params)

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
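`_OraclePoolParams` exists so `oracledb.create_pool(**pool_params)` type-checks: `total=False` lets the wallet keys stay absent for non-autonomous databases while keeping every key's type pinned. The idiom in miniature, with illustrative values:

    from typing import TypedDict

    class PoolParams(TypedDict, total=False):
        user: str
        dsn: str
        wallet_location: str | None

    def build_params(autonomous: bool) -> PoolParams:
        params = PoolParams(user="scott", dsn="localhost/XEPDB1")
        if autonomous:
            # optional keys can be added later without widening the type
            params["wallet_location"] = "/opt/wallet"
        return params

    print(build_params(autonomous=True))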
@@ -9,6 +9,7 @@ dependencies = [
     "boto3>=1.42.91",
     "celery>=5.6.3",
     "croniter>=6.2.2",
+    "flask>=3.1.3,<4.0.0",
     "flask-cors>=6.0.2",
     "gevent>=26.4.0",
     "gevent-websocket>=0.10.1",
@@ -117,7 +118,7 @@ dev = [
     "faker>=40.15.0",
     "lxml-stubs>=0.5.1",
     "basedpyright>=1.39.3",
-    "ruff>=0.15.11",
+    "ruff>=0.15.12",
     "pytest>=9.0.3",
     "pytest-benchmark>=5.2.3",
     "pytest-cov>=7.1.0",
@@ -144,7 +145,7 @@ dev = [
     "types-pexpect>=4.9.0",
     "types-protobuf>=7.34.1",
     "types-psutil>=7.2.2",
-    "types-psycopg2>=2.9.21",
+    "types-psycopg2>=2.9.21.20260422",
     "types-pygments>=2.20.0",
     "types-pymysql>=1.1.0",
     "types-python-dateutil>=2.9.0",
@@ -157,9 +158,9 @@ dev = [
     "types-tensorflow>=2.18.0.20260408",
     "types-tqdm>=4.67.3.20260408",
     "types-ujson>=5.10.0",
-    "boto3-stubs>=1.42.92",
+    "boto3-stubs>=1.42.96",
     "types-jmespath>=1.1.0.20260408",
-    "hypothesis>=6.152.1",
+    "hypothesis>=6.152.3",
     "types_pyOpenSSL>=24.1.0",
     "types_cffi>=2.0.0.20260408",
     "types_setuptools>=82.0.0.20260408",
@@ -169,7 +170,7 @@ dev = [
     "import-linter>=2.3",
     "types-redis>=4.6.0.20241004",
     "celery-types>=0.23.0",
-    "mypy>=1.20.1",
+    "mypy>=1.20.2",
     # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
     "pytest-timeout>=2.4.0",
     "pytest-xdist>=3.8.0",
@@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.time_parser import get_time_threshold
 from models.enums import WorkflowRunTriggeredFrom
-from models.human_input import HumanInputForm
+from models.human_input import HumanInputForm, HumanInputFormRecipient
 from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
 from repositories.entities.workflow_pause import WorkflowPauseEntity
@@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
 def _build_human_input_required_reason(
     reason_model: WorkflowPauseReason,
     form_model: HumanInputForm | None,
+    recipients: Sequence[HumanInputFormRecipient] = (),
 ) -> HumanInputRequired:
     form_content = ""
     inputs = []
@@ -89,7 +90,7 @@ def _build_human_input_required_reason(
     resolved_default_values = dict(definition.default_values)
     node_title = definition.node_title or node_title

-    return HumanInputRequired(
+    reason = HumanInputRequired(
         form_id=form_id,
         form_content=form_content,
         inputs=inputs,
@@ -98,6 +99,7 @@ def _build_human_input_required_reason(
         node_title=node_title,
         resolved_default_values=resolved_default_values,
     )
+    return reason


 class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
         for form in session.scalars(form_stmt).all():
             form_models[form.id] = form
+        recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
+        if form_ids:
+            recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
+            for recipient in session.scalars(recipient_stmt).all():
+                recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)

         pause_reasons: list[PauseReason] = []
         for reason in pause_reason_models:
             if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
                 form_model = form_models.get(reason.form_id)
-                pause_reasons.append(_build_human_input_required_reason(reason, form_model))
+                pause_reasons.append(
+                    _build_human_input_required_reason(
+                        reason,
+                        form_model,
+                        recipients_by_form_id.get(reason.form_id, ()),
+                    )
+                )
             else:
                 pause_reasons.append(reason.to_entity())
         return pause_reasons
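The recipients lookup deliberately avoids an N+1 pattern: a single `IN (...)` select fetches every recipient for the page of forms, and rows are grouped in memory with `setdefault`. The grouping step on its own, with an illustrative row type:

    from collections import namedtuple

    Recipient = namedtuple("Recipient", ["form_id", "email"])  # stand-in for the ORM row

    rows = [Recipient("f1", "a@x"), Recipient("f2", "b@x"), Recipient("f1", "c@x")]
    recipients_by_form_id: dict[str, list[Recipient]] = {}
    for recipient in rows:
        recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)

    assert [r.email for r in recipients_by_form_id["f1"]] == ["a@x", "c@x"]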
@@ -133,7 +133,14 @@ class AppAnnotationService:
                 raise ValueError("'question' is required when 'message_id' is not provided")
             question = maybe_question

-            annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
+            annotation = MessageAnnotation(
+                app_id=app.id,
+                conversation_id=None,
+                message_id=None,
+                content=answer,
+                question=question,
+                account_id=current_user.id,
+            )
         db.session.add(annotation)
         db.session.commit()

@@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
 from core.app.features.rate_limiting.rate_limit import rate_limit_context
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
 from core.db import session_factory
-from enums.quota_type import QuotaType, unlimited
+from enums.quota_type import QuotaType
 from extensions.otel import AppGenerateHandler, trace_span
 from models.model import Account, App, AppMode, EndUser
 from models.workflow import Workflow, WorkflowRun
 from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
 from services.errors.llm import InvokeRateLimitError
+from services.quota_service import QuotaService, unlimited
 from services.workflow_service import WorkflowService
 from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task

@@ -106,7 +107,7 @@ class AppGenerateService:
         quota_charge = unlimited()
         if dify_config.BILLING_ENABLED:
             try:
-                quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
+                quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
             except QuotaExceededError:
                 raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")

@@ -116,6 +117,7 @@ class AppGenerateService:
         request_id = RateLimit.gen_request_key()
         try:
             request_id = rate_limit.enter(request_id)
+            quota_charge.commit()
             effective_mode = (
                 AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
             )
@@ -162,6 +164,7 @@ class AppGenerateService:
                     invoke_from=invoke_from,
                     streaming=True,
                     call_depth=0,
+                    workflow_run_id=str(uuid.uuid4()),
                 )
                 payload_json = payload.model_dump_json()

@@ -183,6 +186,10 @@ class AppGenerateService:
             else:
                 # Blocking mode: run synchronously and return JSON instead of SSE
                 # Keep behaviour consistent with WORKFLOW blocking branch.
+                pause_config = PauseStateLayerConfig(
+                    session_factory=session_factory.get_session_maker(),
+                    state_owner_user_id=workflow.created_by,
+                )
                 advanced_generator = AdvancedChatAppGenerator()
                 return rate_limit.generate(
                     advanced_generator.convert_to_event_stream(
@@ -194,6 +201,7 @@ class AppGenerateService:
                         invoke_from=invoke_from,
                         workflow_run_id=str(uuid.uuid4()),
                         streaming=False,
+                        pause_state_config=pause_config,
                     )
                 ),
                 request_id=request_id,
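Note the ordering: the reservation taken before the limiter is only committed after `rate_limit.enter` succeeds, so a request rejected for concurrency never burns quota. A sketch of that ordering under the lifecycle this diff introduces (the limiter and charge here are stand-ins, and the refund-on-failure arm follows the QuotaCharge contract rather than this exact method):

    class Charge:  # stand-in for QuotaCharge
        def commit(self) -> None: print("quota committed")
        def refund(self) -> None: print("quota refunded")

    def enter_rate_limit() -> None:
        """Raises when too many requests are already in flight."""

    charge = Charge()  # QuotaService.reserve(...) in the real code
    try:
        enter_rate_limit()
        charge.commit()  # commit only once the request is admitted
    except Exception:
        charge.refund()
        raise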
@@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
 from models.workflow import Workflow
 from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
 from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
+from services.quota_service import QuotaService, unlimited
 from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
 from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
 from services.workflow_service import WorkflowService
@@ -88,7 +89,10 @@ class AsyncWorkflowService:
             raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")

         # 2. Get workflow
-        workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
+        workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
+
+        # commit the read-only session before starting the billing rpc call
+        session.commit()

         # 3. Get dispatcher based on tenant subscription
         dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@@ -131,9 +135,10 @@ class AsyncWorkflowService:
         trigger_log = trigger_log_repo.create(trigger_log)
         session.commit()

-        # 7. Check and consume quota
+        # 7. Reserve quota (commit after successful dispatch)
+        quota_charge = unlimited()
         try:
-            QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
+            quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
         except QuotaExceededError as e:
             # Update trigger log status
             trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@@ -153,13 +158,18 @@ class AsyncWorkflowService:
         # 9. Dispatch to appropriate queue
         task_data_dict = task_data.model_dump(mode="json")

-        task: AsyncResult[Any] | None = None
-        if queue_name == QueuePriority.PROFESSIONAL:
-            task = execute_workflow_professional.delay(task_data_dict)
-        elif queue_name == QueuePriority.TEAM:
-            task = execute_workflow_team.delay(task_data_dict)
-        else:  # SANDBOX
-            task = execute_workflow_sandbox.delay(task_data_dict)
+        try:
+            task: AsyncResult[Any] | None = None
+            if queue_name == QueuePriority.PROFESSIONAL:
+                task = execute_workflow_professional.delay(task_data_dict)
+            elif queue_name == QueuePriority.TEAM:
+                task = execute_workflow_team.delay(task_data_dict)
+            else:  # SANDBOX
+                task = execute_workflow_sandbox.delay(task_data_dict)
+            quota_charge.commit()
+        except Exception:
+            quota_charge.refund()
+            raise

         # 10. Update trigger log with task info
         trigger_log.status = WorkflowTriggerStatus.QUEUED
@@ -295,13 +305,21 @@ class AsyncWorkflowService:
         return [log.to_dict() for log in logs]

     @staticmethod
-    def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
+    def _get_workflow(
+        workflow_service: WorkflowService,
+        app_model: App,
+        workflow_id: str | None = None,
+        session: Session | None = None,
+    ) -> Workflow:
         """
         Get workflow for the app

         Args:
             app_model: App model instance
             workflow_id: Optional specific workflow ID
+            session: Reuse this SQLAlchemy session for the lookup when provided,
+                so the caller's explicit session bears the connection cost
+                instead of Flask's request-scoped ``db.session``.

         Returns:
             Workflow instance
@@ -311,12 +329,12 @@ class AsyncWorkflowService:
         """
         if workflow_id:
             # Get specific published workflow
-            workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
+            workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
             if not workflow:
                 raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
         else:
             # Get default published workflow
-            workflow = workflow_service.get_published_workflow(app_model)
+            workflow = workflow_service.get_published_workflow(app_model, session=session)
             if not workflow:
                 raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

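`_get_workflow` now accepts an optional `Session` so the caller's open transaction performs the read instead of Flask's request-scoped `db.session`. The threading idiom in miniature (the lookup function and session class are illustrative):

    class Session:  # minimal stand-in for sqlalchemy.orm.Session
        def get(self, key: str) -> str:
            return f"row:{key}"

    _request_scoped = Session()

    def get_published_workflow(workflow_id: str, session: Session | None = None) -> str:
        db = session if session is not None else _request_scoped  # reuse the caller's session when given
        return db.get(workflow_id)

    print(get_published_workflow("wf-1", session=Session()))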
@@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict):
     expiration_date: int


+class QuotaReserveResult(TypedDict):
+    reservation_id: str
+    available: int
+    reserved: int
+
+
+class QuotaCommitResult(TypedDict):
+    available: int
+    reserved: int
+    refunded: int
+
+
+class QuotaReleaseResult(TypedDict):
+    available: int
+    reserved: int
+    released: int
+
+
+_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
+_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
+_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
+
+
+class _TenantFeatureQuota(TypedDict):
+    usage: int
+    limit: int
+    reset_date: NotRequired[int]
+
+
+class TenantFeatureQuotaInfo(TypedDict):
+    """Response of /quota/info.
+
+    NOTE (hj24):
+    - Same convention as BillingInfo: billing may return int fields as str,
+      always keep non-strict mode to auto-coerce.
+    """
+
+    trigger_event: _TenantFeatureQuota
+    api_rate_limit: _TenantFeatureQuota
+
+
+_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
+
+
 class _BillingQuota(TypedDict):
     size: int
     limit: int
@@ -149,11 +193,63 @@ class BillingService:

     @classmethod
     def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
+        """Deprecated: Use get_quota_info instead."""
         params = {"tenant_id": tenant_id}

         usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
         return usage_info

+    @classmethod
+    def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
+        params = {"tenant_id": tenant_id}
+        return _tenant_feature_quota_info_adapter.validate_python(
+            cls._send_request("GET", "/quota/info", params=params)
+        )
+
+    @classmethod
+    def quota_reserve(
+        cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
+    ) -> QuotaReserveResult:
+        """Reserve quota before task execution."""
+        payload: dict = {
+            "tenant_id": tenant_id,
+            "feature_key": feature_key,
+            "request_id": request_id,
+            "amount": amount,
+        }
+        if meta:
+            payload["meta"] = meta
+        return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))
+
+    @classmethod
+    def quota_commit(
+        cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
+    ) -> QuotaCommitResult:
+        """Commit a reservation with actual consumption."""
+        payload: dict = {
+            "tenant_id": tenant_id,
+            "feature_key": feature_key,
+            "reservation_id": reservation_id,
+            "actual_amount": actual_amount,
+        }
+        if meta:
+            payload["meta"] = meta
+        return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))
+
+    @classmethod
+    def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
+        """Release a reservation (cancel, return frozen quota)."""
+        return _quota_release_adapter.validate_python(
+            cls._send_request(
+                "POST",
+                "/quota/release",
+                json={
+                    "tenant_id": tenant_id,
+                    "feature_key": feature_key,
+                    "reservation_id": reservation_id,
+                },
+            )
+        )
+
     @classmethod
     def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
         params = {"tenant_id": tenant_id}
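`TypeAdapter` gives these TypedDict responses runtime validation without promoting them to full `BaseModel` classes, and the default non-strict mode coerces stringified integers, which is exactly the billing quirk the docstring notes. A sketch of the coercion with an illustrative payload:

    from pydantic import TypeAdapter
    from typing_extensions import TypedDict

    class QuotaReserveResult(TypedDict):
        reservation_id: str
        available: int
        reserved: int

    adapter = TypeAdapter(QuotaReserveResult)

    # lax mode coerces "5" -> 5, matching billing's int-as-str convention
    result = adapter.validate_python({"reservation_id": "r-1", "available": "5", "reserved": 1})
    assert result["available"] == 5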
@@ -3748,6 +3748,7 @@ class SegmentService:
                     ChildChunk.segment_id == segment.id,
                 )
             )
+            assert current_user.current_tenant_id
             child_chunk = ChildChunk(
                 tenant_id=current_user.current_tenant_id,
                 dataset_id=dataset.id,
@@ -3758,7 +3759,7 @@ class SegmentService:
                 index_node_hash=index_node_hash,
                 content=content,
                 word_count=len(content),
-                type="customized",
+                type=SegmentType.CUSTOMIZED,
                 created_by=current_user.id,
             )
             db.session.add(child_chunk)
@@ -3818,6 +3819,7 @@ class SegmentService:
             if new_child_chunks_args:
                 child_chunk_count = len(child_chunks)
                 for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
+                    assert current_user.current_tenant_id
                     index_node_id = str(uuid.uuid4())
                     index_node_hash = helper.generate_text_hash(args.content)
                     child_chunk = ChildChunk(
@@ -3830,7 +3832,7 @@ class SegmentService:
                         index_node_hash=index_node_hash,
                         content=args.content,
                         word_count=len(args.content),
-                        type="customized",
+                        type=SegmentType.CUSTOMIZED,
                         created_by=current_user.id,
                     )
@@ -5,6 +5,7 @@ import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING

+from cachetools.func import ttl_cache
 from pydantic import BaseModel, ConfigDict, Field, model_validator

 from configs import dify_config
@@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:

 class EnterpriseService:
     @classmethod
+    @ttl_cache(ttl=5)
     def get_info(cls):
         return EnterpriseRequest.send_request("GET", "/info")

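`cachetools.func.ttl_cache` memoizes `get_info` for five seconds, so a burst of feature checks collapses into one enterprise RPC; after the TTL lapses the next call refreshes the cache. A self-contained sketch:

    from cachetools.func import ttl_cache

    calls = 0

    @ttl_cache(ttl=5)
    def get_info() -> dict:
        global calls
        calls += 1
        return {"plan": "enterprise"}

    get_info()
    get_info()  # served from the cache within the 5 s window
    assert calls == 1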
@@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
     enable_change_email: bool = True
     plugin_manager: PluginManagerModel = PluginManagerModel()
     trial_models: list[str] = []
+    enable_creators_platform: bool = False
     enable_trial_app: bool = False
     enable_explore_banner: bool = False

@@ -241,6 +242,9 @@ class FeatureService:
         if dify_config.MARKETPLACE_ENABLED:
             system_features.enable_marketplace = True

+        if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
+            system_features.enable_creators_platform = True
+
         return system_features

     @classmethod
@@ -286,7 +290,7 @@ class FeatureService:
     def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
         billing_info = BillingService.get_info(tenant_id)

-        features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
+        features_usage_info = BillingService.get_quota_info(tenant_id)

         features.billing.enabled = billing_info["enabled"]
         features.billing.subscription.plan = billing_info["subscription"]["plan"]
233
api/services/quota_service.py
Normal file
@@ -0,0 +1,233 @@
+from __future__ import annotations
+
+import logging
+import uuid
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
+
+from configs import dify_config
+
+if TYPE_CHECKING:
+    from enums.quota_type import QuotaType
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class QuotaCharge:
+    """
+    Result of a quota reservation (Reserve phase).
+
+    Lifecycle:
+        charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id)
+        try:
+            do_work()
+            charge.commit()  # Confirm consumption
+        except:
+            charge.refund()  # Release frozen quota
+
+    If neither commit() nor refund() is called, the billing system's
+    cleanup CronJob will auto-release the reservation within ~75 seconds.
+    """
+
+    success: bool
+    charge_id: str | None  # reservation_id
+    _quota_type: QuotaType
+    _tenant_id: str | None = None
+    _feature_key: str | None = None
+    _amount: int = 0
+    _committed: bool = field(default=False, repr=False)
+
+    def commit(self, actual_amount: int | None = None) -> None:
+        """
+        Confirm the consumption with actual amount.
+
+        Args:
+            actual_amount: Actual amount consumed. Defaults to the reserved amount.
+                If less than reserved, the difference is refunded automatically.
+        """
+        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
+            return
+
+        try:
+            from services.billing_service import BillingService
+
+            amount = actual_amount if actual_amount is not None else self._amount
+            BillingService.quota_commit(
+                tenant_id=self._tenant_id,
+                feature_key=self._feature_key,
+                reservation_id=self.charge_id,
+                actual_amount=amount,
+            )
+            self._committed = True
+            logger.debug(
+                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
+                self._quota_type,
+                self._tenant_id,
+                self.charge_id,
+                amount,
+            )
+        except Exception:
+            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)
+
+    def refund(self) -> None:
+        """
+        Release the reserved quota (cancel the charge).
+
+        Safe to call even if:
+        - charge failed or was disabled (charge_id is None)
+        - already committed (Release after Commit is a no-op)
+        - already refunded (idempotent)
+
+        This method guarantees no exceptions will be raised.
+        """
+        if not self.charge_id or not self._tenant_id or not self._feature_key:
+            return
+
+        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
+
+
+def unlimited() -> QuotaCharge:
+    from enums.quota_type import QuotaType
+
+    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
+
+
+class QuotaService:
+    """Orchestrates quota reserve / commit / release lifecycle via BillingService."""
+
+    @staticmethod
+    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
+        """
+        Reserve + immediate Commit (one-shot mode).
+
+        The returned QuotaCharge supports .refund() which calls Release.
+        For two-phase usage (e.g. streaming), use reserve() directly.
+        """
+        charge = QuotaService.reserve(quota_type, tenant_id, amount)
+        if charge.success and charge.charge_id:
+            charge.commit()
+        return charge
+
+    @staticmethod
+    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
+        """
+        Reserve quota before task execution (Reserve phase only).
+
+        The caller MUST call charge.commit() after the task succeeds,
+        or charge.refund() if the task fails.
+
+        Raises:
+            QuotaExceededError: When quota is insufficient
+        """
+        from services.billing_service import BillingService
+        from services.errors.app import QuotaExceededError
+
+        if not dify_config.BILLING_ENABLED:
+            logger.debug("Billing disabled, allowing request for %s", tenant_id)
+            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)
+
+        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)
+
+        if amount <= 0:
+            raise ValueError("Amount to reserve must be greater than 0")
+
+        request_id = str(uuid.uuid4())
+        feature_key = quota_type.billing_key
+
+        try:
+            reserve_resp = BillingService.quota_reserve(
+                tenant_id=tenant_id,
+                feature_key=feature_key,
+                request_id=request_id,
+                amount=amount,
+            )
+
+            reservation_id = reserve_resp.get("reservation_id")
+            if not reservation_id:
+                logger.warning(
+                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
+                    tenant_id,
+                    quota_type.value,
+                    reserve_resp,
+                )
+                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
+
+            logger.debug(
+                "Reserved %d %s quota for tenant %s, reservation_id: %s",
+                amount,
+                quota_type.value,
+                tenant_id,
+                reservation_id,
+            )
+            return QuotaCharge(
+                success=True,
+                charge_id=reservation_id,
+                _quota_type=quota_type,
+                _tenant_id=tenant_id,
+                _feature_key=feature_key,
+                _amount=amount,
+            )
+
+        except QuotaExceededError:
+            raise
+        except ValueError:
+            raise
+        except Exception:
+            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
+            return unlimited()
+
+    @staticmethod
+    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
+        if not dify_config.BILLING_ENABLED:
+            return True
+
+        if amount <= 0:
+            raise ValueError("Amount to check must be greater than 0")
+
+        try:
+            remaining = QuotaService.get_remaining(quota_type, tenant_id)
+            return remaining >= amount if remaining != -1 else True
+        except Exception:
+            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
+            return True
+
+    @staticmethod
+    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
+        """Release a reservation. Guarantees no exceptions."""
+        try:
+            from services.billing_service import BillingService
+
+            if not dify_config.BILLING_ENABLED:
+                return
+
+            if not reservation_id:
+                return
+
+            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
+            BillingService.quota_release(
+                tenant_id=tenant_id,
+                feature_key=feature_key,
+                reservation_id=reservation_id,
+            )
+        except Exception:
+            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
+
+    @staticmethod
+    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
+        from services.billing_service import BillingService
+
+        try:
+            usage_info = BillingService.get_quota_info(tenant_id)
+            if isinstance(usage_info, dict):
+                feature_info = usage_info.get(quota_type.billing_key, {})
+                if isinstance(feature_info, dict):
+                    limit = feature_info.get("limit", 0)
+                    usage = feature_info.get("usage", 0)
+                    if limit == -1:
+                        return -1
+                    return max(0, limit - usage)
+            return 0
+        except Exception:
+            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
+            return -1
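Taken together, the new module gives callers a small two-phase protocol: reserve up front, commit on success, refund on failure, with the billing cleanup cron as a backstop for crashed processes. A usage sketch against the API defined above (`run_workflow` is an illustrative stand-in for the dispatched work):

    from enums.quota_type import QuotaType
    from services.quota_service import QuotaService

    def run_workflow() -> None:
        """Illustrative stand-in for the actual task."""

    def run_with_quota(tenant_id: str) -> None:
        charge = QuotaService.reserve(QuotaType.WORKFLOW, tenant_id)  # raises QuotaExceededError when exhausted
        try:
            run_workflow()
            charge.commit()  # confirm consumption once the work is safely underway
        except Exception:
            charge.refund()  # returns the frozen quota; idempotent
            raise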
@@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.encryption import create_provider_encrypter
-from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
+from core.tools.utils.system_encryption import decrypt_system_params
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.provider_ids import ToolProviderID
@@ -521,7 +521,7 @@ class BuiltinToolManageService:
             )
             if system_client:
                 try:
-                    oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
+                    oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
                 except Exception as e:
                     raise ValueError(f"Error decrypting system oauth params: {e}")

@@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
 from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
 from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.oauth import OAuthHandler
-from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
+from core.tools.utils.system_encryption import decrypt_system_params
 from core.trigger.entities.api_entities import (
     TriggerProviderApiEntity,
     TriggerProviderSubscriptionApiEntity,
@@ -635,7 +635,7 @@ class TriggerProviderService:

         if system_client:
             try:
-                oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
+                oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
             except Exception as e:
                 raise ValueError(f"Error decrypting system oauth params: {e}")

@@ -38,6 +38,7 @@ from models.workflow import Workflow
 from services.async_workflow_service import AsyncWorkflowService
 from services.end_user_service import EndUserService
 from services.errors.app import QuotaExceededError
+from services.quota_service import QuotaService
 from services.trigger.app_trigger_service import AppTriggerService
 from services.workflow.entities import WebhookTriggerData

@@ -798,45 +799,47 @@ class WebhookService:
             Exception: If workflow execution fails
         """
         try:
-            with Session(db.engine) as session:
-                # Prepare inputs for the webhook node
-                # The webhook node expects webhook_data in the inputs
-                workflow_inputs = cls.build_workflow_inputs(webhook_data)
+            workflow_inputs = cls.build_workflow_inputs(webhook_data)

-                # Create trigger data
-                trigger_data = WebhookTriggerData(
-                    app_id=webhook_trigger.app_id,
-                    workflow_id=workflow.id,
-                    root_node_id=webhook_trigger.node_id,  # Start from the webhook node
-                    inputs=workflow_inputs,
-                    tenant_id=webhook_trigger.tenant_id,
-                )
+            trigger_data = WebhookTriggerData(
+                app_id=webhook_trigger.app_id,
+                workflow_id=workflow.id,
+                root_node_id=webhook_trigger.node_id,
+                inputs=workflow_inputs,
+                tenant_id=webhook_trigger.tenant_id,
+            )

-                end_user = EndUserService.get_or_create_end_user_by_type(
-                    type=InvokeFrom.TRIGGER,
-                    tenant_id=webhook_trigger.tenant_id,
-                    app_id=webhook_trigger.app_id,
-                    user_id=None,
-                )
+            end_user = EndUserService.get_or_create_end_user_by_type(
+                type=InvokeFrom.TRIGGER,
+                tenant_id=webhook_trigger.tenant_id,
+                app_id=webhook_trigger.app_id,
+                user_id=None,
+            )

-                # consume quota before triggering workflow execution
-                try:
-                    QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
-                except QuotaExceededError:
-                    AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
-                    logger.info(
-                        "Tenant %s rate limited, skipping webhook trigger %s",
-                        webhook_trigger.tenant_id,
-                        webhook_trigger.webhook_id,
-                    )
-                    raise
+            try:
+                quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
+            except QuotaExceededError:
+                AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
+                logger.info(
+                    "Tenant %s rate limited, skipping webhook trigger %s",
+                    webhook_trigger.tenant_id,
+                    webhook_trigger.webhook_id,
+                )
+                raise

-                # Trigger workflow execution asynchronously
-                AsyncWorkflowService.trigger_workflow_async(
-                    session,
-                    end_user,
-                    trigger_data,
-                )
+            try:
+                # NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`;
+                # trigger_workflow_async needs to handle multiple session commits internally
+                with Session(db.engine, expire_on_commit=False) as session:
+                    AsyncWorkflowService.trigger_workflow_async(
+                        session,
+                        end_user,
+                        trigger_data,
+                    )
+                quota_charge.commit()
+            except Exception:
+                quota_charge.refund()
+                raise
+
         except Exception:
             logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
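Moving session construction inside the inner `try` and passing `expire_on_commit=False` keeps ORM objects readable across the intermediate commits `trigger_workflow_async` performs; with the default setting, every commit would expire loaded attributes and force lazy reloads. A minimal illustration of the flag (engine and model are illustrative):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Item(Base):
        __tablename__ = "item"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine, expire_on_commit=False) as session:
        item = Item(name="webhook")
        session.add(item)
        session.commit()

    # attributes stay loaded after commit and even after the session closes;
    # with the default expire_on_commit=True this access would need a live session
    print(item.name)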
@@ -16,6 +16,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
 from models import UploadFile
 from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
 from models.dataset import Document as DatasetDocument
+from models.enums import SegmentType

 logger = logging.getLogger(__name__)

@@ -178,7 +179,7 @@ class VectorService:
                 index_node_hash=child_chunk.metadata["doc_hash"],
                 content=child_chunk.page_content,
                 word_count=len(child_chunk.page_content),
-                type="automatic",
+                type=SegmentType.AUTOMATIC,
                 created_by=dataset_document.created_by,
             )
             db.session.add(child_segment)
@@ -222,6 +223,7 @@ class VectorService:
             )
             documents.append(new_child_document)
         for update_child_chunk in update_child_chunks:
+            assert update_child_chunk.index_node_id
             child_document = Document(
                 page_content=update_child_chunk.content,
                 metadata={
@@ -234,6 +236,7 @@ class VectorService:
             documents.append(child_document)
             delete_node_ids.append(update_child_chunk.index_node_id)
         for delete_child_chunk in delete_child_chunks:
+            assert delete_child_chunk.index_node_id
             delete_node_ids.append(delete_child_chunk.index_node_id)
         if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # update vector index
@@ -246,6 +249,7 @@ class VectorService:
     @classmethod
     def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
         vector = Vector(dataset=dataset)
+        assert child_chunk.index_node_id
         vector.delete_by_ids([child_chunk.index_node_id])

     @classmethod
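The sprinkled `assert ...index_node_id` guards are type-narrowing checks: the column is `Optional` on the ORM model, and the assert documents the invariant while narrowing `str | None` to `str` before the value reaches `delete_by_ids`. In miniature:

    def delete_by_ids(ids: list[str]) -> None:
        print("deleting", ids)

    index_node_id: str | None = "node-1"  # Optional on the ORM model

    assert index_node_id  # narrows str | None -> str for the checker
    delete_by_ids([index_node_id])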
Some files were not shown because too many files have changed in this diff.