Merge branch 'main' into tp

This commit is contained in:
JzoNg 2026-04-27 10:20:08 +08:00
commit bdecea34a3
316 changed files with 12484 additions and 7806 deletions

View File

@ -16,7 +16,7 @@ concurrency:
jobs: jobs:
api-unit: api-unit:
name: API Unit Tests name: API Unit Tests
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
env: env:
COVERAGE_FILE: coverage-unit COVERAGE_FILE: coverage-unit
defaults: defaults:
@ -62,7 +62,7 @@ jobs:
api-integration: api-integration:
name: API Integration Tests name: API Integration Tests
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
env: env:
COVERAGE_FILE: coverage-integration COVERAGE_FILE: coverage-integration
STORAGE_TYPE: opendal STORAGE_TYPE: opendal
@ -137,7 +137,7 @@ jobs:
api-coverage: api-coverage:
name: API Coverage name: API Coverage
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
needs: needs:
- api-unit - api-unit
- api-integration - api-integration

View File

@ -13,7 +13,7 @@ permissions:
jobs: jobs:
autofix: autofix:
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Complete merge group check - name: Complete merge group check
if: github.event_name == 'merge_group' if: github.event_name == 'merge_group'

View File

@ -26,6 +26,9 @@ jobs:
build: build:
runs-on: ${{ matrix.runs_on }} runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
permissions:
contents: read
id-token: write
strategy: strategy:
matrix: matrix:
include: include:
@ -35,28 +38,28 @@ jobs:
build_context: "{{defaultContext}}:api" build_context: "{{defaultContext}}:api"
file: "Dockerfile" file: "Dockerfile"
platform: linux/amd64 platform: linux/amd64
runs_on: ubuntu-latest runs_on: depot-ubuntu-24.04-4
- service_name: "build-api-arm64" - service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME" image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api" artifact_context: "api"
build_context: "{{defaultContext}}:api" build_context: "{{defaultContext}}:api"
file: "Dockerfile" file: "Dockerfile"
platform: linux/arm64 platform: linux/arm64
runs_on: ubuntu-24.04-arm runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-amd64" - service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME" image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web" artifact_context: "web"
build_context: "{{defaultContext}}" build_context: "{{defaultContext}}"
file: "web/Dockerfile" file: "web/Dockerfile"
platform: linux/amd64 platform: linux/amd64
runs_on: ubuntu-latest runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-arm64" - service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME" image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web" artifact_context: "web"
build_context: "{{defaultContext}}" build_context: "{{defaultContext}}"
file: "web/Dockerfile" file: "web/Dockerfile"
platform: linux/arm64 platform: linux/arm64
runs_on: ubuntu-24.04-arm runs_on: depot-ubuntu-24.04-4
steps: steps:
- name: Prepare - name: Prepare
@ -70,8 +73,8 @@ jobs:
username: ${{ env.DOCKERHUB_USER }} username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }} password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Depot CLI
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 uses: depot/setup-action@v1
- name: Extract metadata for Docker - name: Extract metadata for Docker
id: meta id: meta
@ -81,16 +84,15 @@ jobs:
- name: Build Docker image - name: Build Docker image
id: build id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 uses: depot/build-push-action@v1
with: with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }} context: ${{ matrix.build_context }}
file: ${{ matrix.file }} file: ${{ matrix.file }}
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest - name: Export digest
env: env:
@ -108,9 +110,33 @@ jobs:
if-no-files-found: error if-no-files-found: error
retention-days: 1 retention-days: 1
fork-build-validate:
if: github.repository != 'langgenius/dify'
runs-on: ubuntu-24.04
strategy:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Validate Docker image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: linux/amd64
create-manifest: create-manifest:
needs: build needs: build
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
strategy: strategy:
matrix: matrix:

View File

@ -9,7 +9,7 @@ concurrency:
jobs: jobs:
db-migration-test-postgres: db-migration-test-postgres:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Checkout code - name: Checkout code
@ -59,7 +59,7 @@ jobs:
run: uv run --directory api flask upgrade-db run: uv run --directory api flask upgrade-db
db-migration-test-mysql: db-migration-test-mysql:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Checkout code - name: Checkout code

View File

@ -13,7 +13,7 @@ on:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev' github.event.workflow_run.head_branch == 'deploy/agent-dev'

View File

@ -10,7 +10,7 @@ on:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev' github.event.workflow_run.head_branch == 'deploy/dev'

View File

@ -13,7 +13,7 @@ on:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise' github.event.workflow_run.head_branch == 'deploy/enterprise'

View File

@ -10,7 +10,7 @@ on:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl' github.event.workflow_run.head_branch == 'build/feat/hitl'

View File

@ -14,40 +14,69 @@ concurrency:
jobs: jobs:
build-docker: build-docker:
if: github.event.pull_request.head.repo.full_name == github.repository
runs-on: ${{ matrix.runs_on }} runs-on: ${{ matrix.runs_on }}
permissions:
contents: read
id-token: write
strategy: strategy:
matrix: matrix:
include: include:
- service_name: "api-amd64" - service_name: "api-amd64"
platform: linux/amd64 platform: linux/amd64
runs_on: ubuntu-latest runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api" context: "{{defaultContext}}:api"
file: "Dockerfile" file: "Dockerfile"
- service_name: "api-arm64" - service_name: "api-arm64"
platform: linux/arm64 platform: linux/arm64
runs_on: ubuntu-24.04-arm runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api" context: "{{defaultContext}}:api"
file: "Dockerfile" file: "Dockerfile"
- service_name: "web-amd64" - service_name: "web-amd64"
platform: linux/amd64 platform: linux/amd64
runs_on: ubuntu-latest runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}" context: "{{defaultContext}}"
file: "web/Dockerfile" file: "web/Dockerfile"
- service_name: "web-arm64" - service_name: "web-arm64"
platform: linux/arm64 platform: linux/arm64
runs_on: ubuntu-24.04-arm runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}" context: "{{defaultContext}}"
file: "web/Dockerfile" file: "web/Dockerfile"
steps: steps:
- name: Set up Docker Buildx - name: Set up Depot CLI
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 uses: depot/setup-action@v1
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 uses: depot/build-push-action@v1
with: with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false push: false
context: ${{ matrix.context }} context: ${{ matrix.context }}
file: ${{ matrix.file }} file: ${{ matrix.file }}
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max build-docker-fork:
if: github.event.pull_request.head.repo.full_name != github.repository
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Build Docker Image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: linux/amd64

View File

@ -7,7 +7,7 @@ jobs:
permissions: permissions:
contents: read contents: read
pull-requests: write pull-requests: write
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with: with:

View File

@ -23,7 +23,7 @@ concurrency:
jobs: jobs:
pre_job: pre_job:
name: Skip Duplicate Checks name: Skip Duplicate Checks
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
outputs: outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps: steps:
@ -39,7 +39,7 @@ jobs:
name: Check Changed Files name: Check Changed Files
needs: pre_job needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true' if: needs.pre_job.outputs.should_skip != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
outputs: outputs:
api-changed: ${{ steps.changes.outputs.api }} api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }} e2e-changed: ${{ steps.changes.outputs.e2e }}
@ -141,7 +141,7 @@ jobs:
- pre_job - pre_job
- check-changes - check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Report skipped API tests - name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests." run: echo "No API-related changes detected; skipping API tests."
@ -154,7 +154,7 @@ jobs:
- check-changes - check-changes
- api-tests-run - api-tests-run
- api-tests-skip - api-tests-skip
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Finalize API Tests status - name: Finalize API Tests status
env: env:
@ -201,7 +201,7 @@ jobs:
- pre_job - pre_job
- check-changes - check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Report skipped web tests - name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests." run: echo "No web-related changes detected; skipping web tests."
@ -214,7 +214,7 @@ jobs:
- check-changes - check-changes
- web-tests-run - web-tests-run
- web-tests-skip - web-tests-skip
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Finalize Web Tests status - name: Finalize Web Tests status
env: env:
@ -260,7 +260,7 @@ jobs:
- pre_job - pre_job
- check-changes - check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Report skipped web full-stack e2e - name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E." run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@ -273,7 +273,7 @@ jobs:
- check-changes - check-changes
- web-e2e-run - web-e2e-run
- web-e2e-skip - web-e2e-skip
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Finalize Web Full-Stack E2E status - name: Finalize Web Full-Stack E2E status
env: env:
@ -325,7 +325,7 @@ jobs:
- pre_job - pre_job
- check-changes - check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Report skipped VDB tests - name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests." run: echo "No VDB-related changes detected; skipping VDB tests."
@ -338,7 +338,7 @@ jobs:
- check-changes - check-changes
- vdb-tests-run - vdb-tests-run
- vdb-tests-skip - vdb-tests-skip
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Finalize VDB Tests status - name: Finalize VDB Tests status
env: env:
@ -384,7 +384,7 @@ jobs:
- pre_job - pre_job
- check-changes - check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Report skipped DB migration tests - name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests." run: echo "No migration-related changes detected; skipping DB migration tests."
@ -397,7 +397,7 @@ jobs:
- check-changes - check-changes
- db-migration-test-run - db-migration-test-run
- db-migration-test-skip - db-migration-test-skip
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Finalize DB Migration Test status - name: Finalize DB Migration Test status
env: env:

View File

@ -12,7 +12,7 @@ permissions: {}
jobs: jobs:
comment: comment:
name: Comment PR with pyrefly diff name: Comment PR with pyrefly diff
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
permissions: permissions:
actions: read actions: read
contents: read contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs: jobs:
pyrefly-diff: pyrefly-diff:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
permissions: permissions:
contents: read contents: read
issues: write issues: write

View File

@ -12,7 +12,7 @@ permissions: {}
jobs: jobs:
comment: comment:
name: Comment PR with type coverage name: Comment PR with type coverage
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
permissions: permissions:
actions: read actions: read
contents: read contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs: jobs:
pyrefly-type-coverage: pyrefly-type-coverage:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
permissions: permissions:
contents: read contents: read
issues: write issues: write

View File

@ -16,7 +16,7 @@ jobs:
name: Validate PR title name: Validate PR title
permissions: permissions:
pull-requests: read pull-requests: read
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Complete merge group check - name: Complete merge group check
if: github.event_name == 'merge_group' if: github.event_name == 'merge_group'

View File

@ -12,7 +12,7 @@ on:
jobs: jobs:
stale: stale:
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
permissions: permissions:
issues: write issues: write
pull-requests: write pull-requests: write

View File

@ -15,7 +15,7 @@ permissions:
jobs: jobs:
python-style: python-style:
name: Python Style name: Python Style
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Checkout code - name: Checkout code
@ -57,7 +57,7 @@ jobs:
web-style: web-style:
name: Web Style name: Web Style
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
defaults: defaults:
run: run:
working-directory: ./web working-directory: ./web
@ -131,7 +131,7 @@ jobs:
superlinter: superlinter:
name: SuperLinter name: SuperLinter
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
steps: steps:
- name: Checkout code - name: Checkout code

View File

@ -18,7 +18,7 @@ concurrency:
jobs: jobs:
build: build:
name: unit test for Node.js SDK name: unit test for Node.js SDK
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
defaults: defaults:
run: run:

View File

@ -35,7 +35,7 @@ concurrency:
jobs: jobs:
translate: translate:
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync - name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != '' if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101 uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
with: with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -16,7 +16,7 @@ concurrency:
jobs: jobs:
trigger: trigger:
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
timeout-minutes: 5 timeout-minutes: 5
steps: steps:

View File

@ -16,7 +16,7 @@ jobs:
test: test:
name: Full VDB Tests name: Full VDB Tests
if: github.repository == 'langgenius/dify' if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
strategy: strategy:
matrix: matrix:
python-version: python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs: jobs:
test: test:
name: VDB Smoke Tests name: VDB Smoke Tests
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
strategy: strategy:
matrix: matrix:
python-version: python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs: jobs:
test: test:
name: Web Full-Stack E2E name: Web Full-Stack E2E
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
defaults: defaults:
run: run:
shell: bash shell: bash

View File

@ -16,7 +16,7 @@ concurrency:
jobs: jobs:
test: test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }}) name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
env: env:
VITEST_COVERAGE_SCOPE: app-components VITEST_COVERAGE_SCOPE: app-components
strategy: strategy:
@ -54,7 +54,7 @@ jobs:
name: Merge Test Reports name: Merge Test Reports
if: ${{ !cancelled() }} if: ${{ !cancelled() }}
needs: [test] needs: [test]
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
env: env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults: defaults:
@ -92,7 +92,7 @@ jobs:
dify-ui-test: dify-ui-test:
name: dify-ui Tests name: dify-ui Tests
runs-on: ubuntu-latest runs-on: depot-ubuntu-24.04
env: env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults: defaults:

View File

@ -147,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
### Deployment with Kubernetes ### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration # Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db from extensions.ext_database import db
from models import Tenant from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict) oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green")) click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e: except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict) oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green")) click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e: except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
) )
class CreatorsPlatformConfig(BaseSettings):
    """
    Configuration for Creators Platform integration
    """

    # Feature switch for the Creators Platform integration; defaults to enabled.
    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
        default=True,
        description="Enable or disable Creators Platform features",
    )

    # URL of the Creators Platform API; defaults to the hosted endpoint below.
    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
        default=HttpUrl("https://creators.dify.ai"),
        description="Creators Platform API URL",
    )

    # OAuth client ID for the integration; an empty string is the default.
    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
        default="",
        description="OAuth client ID for Creators Platform integration",
    )
class EndpointConfig(BaseSettings): class EndpointConfig(BaseSettings):
""" """
Configuration for various application endpoints and URLs Configuration for various application endpoints and URLs
@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig, BillingConfig,
CodeExecutionSandboxConfig, CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig, TriggerConfig,
AsyncWorkflowConfig, AsyncWorkflowConfig,
PluginConfig, PluginConfig,

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
    # Mapping of string keys to JSON-compatible values.
    inputs: dict[str, JsonValue]
    # Action string sent together with the inputs.
    action: str

View File

@ -692,6 +692,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json") return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=None)
    @edit_permission_required
    def post(self, app_model):
        """Publish app to Creators Platform"""
        # These imports are resolved at call time rather than at module import.
        from configs import dify_config
        from core.helper.creators import get_redirect_url, upload_dsl

        # Refuse the request with 403 when the feature flag is off.
        if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
            return {"error": "Creators Platform features are not enabled"}, 403

        account, _ = current_account_with_tenant()

        # Export the DSL without secrets and upload it as UTF-8 bytes;
        # the upload returns the claim code used to build the redirect URL.
        exported_dsl = AppDslService.export_dsl(app_model=app_model, include_secret=False)
        claim_code = upload_dsl(exported_dsl.encode("utf-8"))

        return {"redirect_url": get_redirect_url(str(account.id), claim_code)}
@console_ns.route("/apps/<uuid:app_id>/name") @console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource): class AppNameApi(Resource):
@console_ns.doc("check_app_name") @console_ns.doc("check_app_name")

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request from flask import Response, jsonify, request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response: def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump() payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp()) payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id: if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found") raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
    """Raise NotFoundError unless the form's recipient type is allowed on the console surface."""
    if is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
        return
    # A disallowed recipient type is reported as a missing form.
    raise NotFoundError("form not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}") raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form) self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enough to validate the following invariant. # The type checker is not smart enough to validate the following invariant.
# So we need to assert it manually. # So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here." assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type") type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
    # ID of the target the tag is unbound from.
    target_id: str = Field(description="Target ID to unbind tag from")
    # Tag type, restricted to the TagType values.
    type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel): class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter") type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword") keyword: str | None = Field(None, description="Search keyword")
@ -70,6 +75,7 @@ register_schema_models(
TagBasePayload, TagBasePayload,
TagBindingPayload, TagBindingPayload,
TagBindingRemovePayload, TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam, TagListQueryParam,
TagResponse, TagResponse,
) )
@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
return "", 204 return "", 204
@console_ns.route("/tag-bindings/create") def _require_tag_binding_edit_permission() -> None:
class TagBindingCreateApi(Resource): """
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, _ = current_account_with_tenant() return _create_tag_bindings()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding( @console_ns.route("/tag-bindings/<uuid:id>")
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove") @console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, _ = current_account_with_tenant() return _remove_tag_binding()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200

View File

@ -23,9 +23,11 @@ from .app import (
conversation, conversation,
file, file,
file_preview, file_preview,
human_input_form,
message, message,
site, site,
workflow, workflow,
workflow_events,
) )
from .dataset import ( from .dataset import (
dataset, dataset,
@ -50,6 +52,7 @@ __all__ = [
"file", "file",
"file_preview", "file_preview",
"hit_testing", "hit_testing",
"human_input_form",
"index", "index",
"message", "message",
"metadata", "metadata",
@ -58,6 +61,7 @@ __all__ = [
"segment", "segment",
"site", "site",
"workflow", "workflow",
"workflow_events",
] ]
api.add_namespace(service_api_ns) api.add_namespace(service_api_ns)

View File

@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
    """Serialize a form's definition into an ``application/json`` response.

    The body exposes the rendered content, inputs, stringified default values,
    user actions, and the form expiration time as integer epoch seconds.
    """
    definition = form.get_definition().model_dump()
    # Serialized by hand so non-ASCII characters are emitted unescaped.
    body = json.dumps(
        {
            "form_content": definition["rendered_content"],
            "inputs": definition["inputs"],
            "resolved_default_values": _stringify_default_values(definition["default_values"]),
            "user_actions": definition["user_actions"],
            "expiration_time": _to_timestamp(form.expiration_time),
        },
        ensure_ascii=False,
    )
    return Response(body, mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
    """Raise ``NotFound`` unless the form's app id and tenant id both match ``app_model``."""
    if form.app_id == app_model.id and form.tenant_id == app_model.tenant_id:
        return
    raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
    """Raise ``NotFound`` when the form's recipient type is not allowed on the service API surface."""
    # Keep app-token callers scoped to the public web-form surface; internal HITL
    # routes must continue to flow through console-only authentication.
    permitted = is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API)
    if permitted:
        return
    raise NotFound("Form not found")
# Resource exposing GET (fetch definition) and POST (submit) for a human input
# form addressed by its token. Both handlers apply the same app-ownership and
# surface checks before touching the form.
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
    @service_api_ns.doc("get_human_input_form")
    @service_api_ns.doc(description="Get a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token
    def get(self, app_model: App, form_token: str):
        human_input_service = HumanInputService(db.engine)
        pending_form = human_input_service.get_form_by_token(form_token)
        if pending_form is None:
            raise NotFound("Form not found")

        # Ownership and surface checks come first; the active check runs last.
        _ensure_form_belongs_to_app(pending_form, app_model)
        _ensure_form_is_allowed_for_service_api(pending_form)
        human_input_service.ensure_form_active(pending_form)

        return _jsonify_form_definition(pending_form)

    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
    @service_api_ns.doc("submit_human_input_form")
    @service_api_ns.doc(description="Submit a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form submitted successfully",
            400: "Bad request - invalid submission data",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, form_token: str):
        # The request body is validated before the form is looked up.
        submission = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})

        human_input_service = HumanInputService(db.engine)
        pending_form = human_input_service.get_form_by_token(form_token)
        if pending_form is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(pending_form, app_model)
        _ensure_form_is_allowed_for_service_api(pending_form)

        form_recipient_type = pending_form.recipient_type
        if form_recipient_type is None:
            logger.warning("Recipient type is None for form, form_id=%s", pending_form.id)
            raise BadRequest("Form recipient type is invalid")

        try:
            human_input_service.submit_form_by_token(
                recipient_type=form_recipient_type,
                form_token=form_token,
                selected_action_id=submission.action,
                form_data=submission.inputs,
                submission_end_user_id=end_user.id,
            )
        except FormNotFoundError:
            # Surface the service-level "not found" as an HTTP 404.
            raise NotFound("Form not found")

        return {}, 200

View File

@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
    """Service API for getting workflow execution events after resume."""

    @service_api_ns.doc("get_workflow_events")
    @service_api_ns.doc(description="Get workflow execution events stream after resume")
    @service_api_ns.doc(
        params={
            "task_id": "Workflow run ID",
            "user": "End user identifier (query param)",
            "include_state_snapshot": (
                "Whether to replay from persisted state snapshot, "
                'specify `"true"` to include a status snapshot of executed nodes'
            ),
            "continue_on_pause": (
                # Trailing space keeps the two concatenated literals from
                # rendering as "events,specify" in the generated API docs.
                "Whether to keep the stream open across workflow_paused events, "
                'specify `"true"` to keep the stream open for `workflow_paused` events.'
            ),
        }
    )
    @service_api_ns.doc(
        responses={
            200: "SSE event stream",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, task_id: str):
        # Only workflow and advanced-chat apps are served by this endpoint.
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            raise NotWorkflowAppError()

        session_maker = sessionmaker(db.engine)
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
            tenant_id=app_model.tenant_id,
            run_id=task_id,
        )

        # Every mismatch (missing run, other app, not end-user created, other
        # end user) is reported with the same 404 message.
        if workflow_run is None:
            raise NotFound("Workflow run not found")

        if workflow_run.app_id != app_model.id:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by_role != CreatorUserRole.END_USER:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by != end_user.id:
            raise NotFound("Workflow run not found")

        workflow_run_entity = workflow_run

        if workflow_run_entity.finished_at is not None:
            # Finished run: emit a single finish event built from the stored run.
            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
                task_id=workflow_run_entity.id,
                workflow_run=workflow_run_entity,
                creator_user=end_user,
            )

            payload = response.model_dump(mode="json")
            payload["event"] = response.event.value

            def _generate_finished_events() -> Generator[str, None, None]:
                yield f"data: {json.dumps(payload)}\n\n"

            event_generator = _generate_finished_events
        else:
            msg_generator = MessageGenerator()
            generator: BaseAppGenerator
            if app_mode == AppMode.ADVANCED_CHAT:
                generator = AdvancedChatAppGenerator()
            elif app_mode == AppMode.WORKFLOW:
                generator = WorkflowAppGenerator()
            else:
                raise NotWorkflowAppError()

            # Both flags are opt-in: only the literal "true" (any case) enables them.
            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
            continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
            # An empty list is passed when continuing across pauses; otherwise None.
            terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None

            def _generate_stream_events():
                if include_state_snapshot:
                    return generator.convert_to_event_stream(
                        build_workflow_event_stream(
                            app_mode=app_mode,
                            workflow_run=workflow_run_entity,
                            tenant_id=app_model.tenant_id,
                            app_id=app_model.id,
                            session_maker=session_maker,
                            human_input_surface=HumanInputSurface.SERVICE_API,
                            close_on_pause=not continue_on_pause,
                        )
                    )
                return generator.convert_to_event_stream(
                    msg_generator.retrieve_events(
                        app_mode,
                        workflow_run_entity.id,
                        terminal_events=terminal_events,
                    ),
                )

            event_generator = _generate_stream_events

        return Response(
            event_generator(),
            mimetype="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request from flask import Response, request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload from controllers.web.site import serialize_app_site_payload
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter( _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit", prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser, user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]: ) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppBlockingResponse, AdvancedChatPausedBlockingResponse,
AppStreamResponse, AppStreamResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse, NodeFinishStreamResponse,
NodeStartStreamResponse, NodeStartStreamResponse,
PingStreamResponse, PingStreamResponse,
StreamEvent,
) )
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): class AdvancedChatAppGenerateResponseConverter(
_blocking_response_type = ChatbotAppBlockingResponse AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
:return: :return:
""" """
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = { response = {
"event": "message", "event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id, "task_id": blocking_response.task_id,
"id": blocking_response.data.id, "id": blocking_response.data.id,
"message_id": blocking_response.data.message_id, "message_id": blocking_response.data.message_id,
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response) response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {}) metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata) if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response return response

View File

@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage, WorkflowQueueMessage,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse, MessageAudioEndStreamResponse,
MessageAudioStreamResponse, MessageAudioStreamResponse,
MessageEndStreamResponse, MessageEndStreamResponse,
PingStreamResponse, PingStreamResponse,
StreamResponse, StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState, WorkflowTaskState,
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer: if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
""" """
Process generate task pipeline. Process generate task pipeline.
:return: :return:
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse: def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
""" """
Process blocking response. Process blocking response.
:return: :return:
""" """
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator: for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse): if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse): elif isinstance(stream_response, MessageEndStreamResponse):
extras = {} extras = {}
if stream_response.metadata: if stream_response.metadata:
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else: else:
continue continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.") raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
    self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
    """Build a paused blocking response from collected human-input-required events.

    :param human_input_responses: human input responses gathered from the stream;
        the last element supplies ``workflow_run_id``, so the list must be non-empty
    :return: blocking response whose status is ``WorkflowExecutionStatus.PAUSED``
    """
    runtime_state = self._resolve_graph_runtime_state()
    # dict.fromkeys de-duplicates node ids while keeping first-seen order.
    paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
    # One JSON-mode reason payload per human input response (not de-duplicated).
    reasons = [
        HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
        for response in human_input_responses
    ]

    return AdvancedChatPausedBlockingResponse(
        task_id=self._application_generate_entity.task_id,
        data=AdvancedChatPausedBlockingResponse.Data(
            id=self._message_id,
            mode=self._conversation_mode,
            conversation_id=self._conversation_id,
            message_id=self._message_id,
            workflow_run_id=human_input_responses[-1].workflow_run_id,
            answer=self._task_state.answer,
            metadata=self._message_end_to_stream_response().metadata,
            created_at=self._message_created_at,
            paused_nodes=paused_nodes,
            reasons=reasons,
            status=WorkflowExecutionStatus.PAUSED,
            # Elapsed time is measured from the base pipeline's recorded start.
            elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
            total_tokens=runtime_state.total_tokens,
            total_steps=runtime_state.node_run_steps,
        ),
    )
def _to_stream_response( def _to_stream_response(
self, generator: Generator[StreamResponse, None, None] self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]: ) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, cast from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse, AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
) )
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,

View File

@ -1,7 +1,9 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Union from typing import Any, Union, cast
from pydantic import JsonValue
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC): class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
_blocking_response_type: type[AppBlockingResponse] @classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
@classmethod @classmethod
def convert( def convert(
@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else: else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]: def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response() return _generate_full_response()
else: else:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response) return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else: else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]: def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
return metadata return metadata
@classmethod @classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]: def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
""" """
Error to stream response. Error to stream response.
:param e: exception :param e: exception
:return: :return:
""" """
error_responses: dict[type[Exception], dict[str, Any]] = { error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400}, ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: { QuotaExceededError: {
@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
} }
# Determine the response based on the type of exception # Determine the response based on the type of exception
data: dict[str, Any] | None = None data: dict[str, JsonValue] | None = None
for k, v in error_responses.items(): for k, v in error_responses.items():
if isinstance(e, k): if isinstance(e, k):
data = v data = v

View File

@ -1,6 +1,8 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, cast from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse, AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
) )
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id, "conversation_id": chunk.conversation_id,
"message_id": chunk.message_id, "message_id": chunk.message_id,

View File

@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError): except (TypeError, json.JSONDecodeError):
definition_payload = {} definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = [] responses: list[StreamResponse] = []

View File

@ -1,6 +1,8 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, cast from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse, AppStreamResponse,
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
) )
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
:return: :return:
""" """
response = { response: dict[str, Any] = {
"event": "message", "event": "message",
"task_id": blocking_response.task_id, "task_id": blocking_response.task_id,
"id": blocking_response.data.id, "id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"message_id": chunk.message_id, "message_id": chunk.message_id,
"created_at": chunk.created_at, "created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping" yield "ping"
continue continue
response_chunk = { response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value, "event": sub_stream_response.event.value,
"message_id": chunk.message_id, "message_id": chunk.message_id,
"created_at": chunk.created_at, "created_at": chunk.created_at,

View File

@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic from libs.broadcast_channel.channel import Topic
from models.model import AppMode from models.model import AppMode
@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300, idle_timeout=300,
ping_interval: float = 10.0, ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None, on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]: ) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id) topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events( return stream_topic_events(
@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout, idle_timeout=idle_timeout,
ping_interval=ping_interval, ping_interval=ping_interval,
on_subscribe=on_subscribe, on_subscribe=on_subscribe,
terminal_events=terminal_events,
) )

View File

@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
) )
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump()) return dict(blocking_response.model_dump())
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceProviderType, DatasourceProviderType,
OnlineDriveBrowseFilesRequest, OnlineDriveBrowseFilesRequest,
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser, user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: ) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events: if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set() values: set[str] = set()
for item in terminal_events: for item in terminal_events:

View File

@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser, user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: ) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
""" """
Handle response. Handle response.
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse, NodeStartStreamResponse,
PingStreamResponse, PingStreamResponse,
WorkflowAppBlockingResponse, WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse, WorkflowAppStreamResponse,
) )
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): class WorkflowAppGenerateResponseConverter(
_blocking_response_type = WorkflowAppBlockingResponse AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
:return: :return:
""" """
return blocking_response.model_dump() return dict(blocking_response.model_dump())
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse, MessageAudioEndStreamResponse,
MessageAudioStreamResponse, MessageAudioStreamResponse,
PingStreamResponse, PingStreamResponse,
StreamResponse, StreamResponse,
TextChunkStreamResponse, TextChunkStreamResponse,
WorkflowAppBlockingResponse, WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse, WorkflowAppStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse, WorkflowPauseStreamResponse,
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) )
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
""" """
Process generate task pipeline. Process generate task pipeline.
:return: :return:
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
""" """
To blocking response. To blocking response.
:return: :return:
""" """
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator: for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse): if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse): elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse( return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id, workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data( data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id, id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id, workflow_id=self._workflow.id,
status=stream_response.data.status, status=stream_response.data.status,
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps, total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at, created_at=stream_response.data.created_at,
finished_at=None, finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
), ),
) )
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse): elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse( return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id, workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data( data=WorkflowAppBlockingResponse.Data(
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
), ),
) )
return response
else: else:
continue continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.") raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
    self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
    """
    Build a paused blocking response out of collected human-input-required events.

    :param human_input_responses: human-input-required stream responses, in the order they were received;
        the last element supplies the workflow run id
    :return: blocking response with status ``PAUSED``, empty outputs and one reason per event
    """
    graph_state = self._resolve_graph_runtime_state()

    # Keep each node id once, in first-seen order.
    ordered_node_ids: dict[str, None] = {}
    for event in human_input_responses:
        ordered_node_ids.setdefault(event.data.node_id, None)

    # NOTE(review): created_at is taken from the runtime state's ``start_at`` — confirm that this value
    # is a wall-clock epoch timestamp and not a monotonic counter.
    started_at = int(graph_state.start_at)

    # One JSON-serializable pause reason per human-input event (duplicates are not collapsed here).
    pause_reasons = []
    for event in human_input_responses:
        payload = HumanInputRequiredPauseReasonPayload.from_response_data(event.data)
        pause_reasons.append(payload.model_dump(mode="json"))

    latest_run_id = human_input_responses[-1].workflow_run_id
    paused_data = WorkflowAppPausedBlockingResponse.Data(
        id=latest_run_id,
        workflow_id=self._workflow.id,
        status=WorkflowExecutionStatus.PAUSED,
        outputs={},
        error=None,
        elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
        total_tokens=graph_state.total_tokens,
        total_steps=graph_state.node_run_steps,
        created_at=started_at,
        finished_at=None,
        paused_nodes=list(ordered_node_ids),
        reasons=pause_reasons,
    )
    return WorkflowAppPausedBlockingResponse(
        task_id=self._application_generate_entity.task_id,
        workflow_run_id=latest_run_id,
        data=paused_data,
    )
def _to_stream_response( def _to_stream_response(
self, generator: Generator[StreamResponse, None, None] self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]: ) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -1,12 +1,13 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from enum import StrEnum from enum import StrEnum
from typing import Any from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction from graphon.nodes.human_input.entities import FormInput, UserAction
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
    """
    Pause reason of type ``human_input_required`` as exposed by blocking responses.

    It is assembled from the data of a ``HumanInputRequiredResponse`` event, for the case where
    those events are the only source of pause information available.
    """

    # Discriminator, fixed to the human-input-required pause reason type.
    TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
    form_id: str
    node_id: str
    node_title: str
    form_content: str
    inputs: Sequence[FormInput] = Field(default_factory=list)
    actions: Sequence[UserAction] = Field(default_factory=list)
    display_in_ui: bool = False
    form_token: str | None = None
    resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
    # NOTE(review): presumably a Unix timestamp in seconds — confirm against the event producer.
    expiration_time: int

    @classmethod
    def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
        """
        Create a payload by copying the form fields of a human-input-required event.

        :param data: data section of a ``HumanInputRequiredResponse``
        :return: payload carrying the same field values; ``TYPE`` keeps its default
        """
        payload = cls(
            form_id=data.form_id,
            node_id=data.node_id,
            node_title=data.node_title,
            form_content=data.form_content,
            inputs=data.inputs,
            actions=data.actions,
            display_in_ui=data.display_in_ui,
            form_token=data.form_token,
            resolved_default_values=data.resolved_default_values,
            expiration_time=data.expiration_time,
        )
        return payload
class HumanInputFormFilledResponse(StreamResponse): class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel): class Data(BaseModel):
""" """
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
def to_ignore_detail_dict(self): def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return { return {
"event": self.event.value, "event": self.event.value,
"task_id": self.task_id, "task_id": self.task_id,
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
def to_ignore_detail_dict(self): def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return { return {
"event": self.event.value, "event": self.event.value,
"task_id": self.task_id, "task_id": self.task_id,
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
    """
    AdvancedChatPausedBlockingResponse entity

    Carries the chat message fields together with the pause information
    (``paused_nodes``, ``reasons``, ``status``) and run statistics.
    """

    class Data(BaseModel):
        """
        Data entity
        """

        id: str
        mode: str
        conversation_id: str
        message_id: str
        workflow_run_id: str
        answer: str
        metadata: Mapping[str, object] = Field(default_factory=dict)
        created_at: int
        # Pause information; both default to empty sequences when omitted.
        paused_nodes: Sequence[str] = Field(default_factory=list)
        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
        status: WorkflowExecutionStatus
        # Run statistics.
        elapsed_time: float
        total_tokens: int
        total_steps: int

    data: Data
class CompletionAppBlockingResponse(AppBlockingResponse): class CompletionAppBlockingResponse(AppBlockingResponse):
""" """
CompletionAppBlockingResponse entity CompletionAppBlockingResponse entity
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
    """
    Blocking response entity returned when a workflow run pauses instead of finishing.

    Besides the run fields it lists the paused nodes and the reasons for the pause.
    """

    class Data(BaseModel):
        """
        Payload of the paused blocking response.
        """

        id: str
        workflow_id: str
        status: WorkflowExecutionStatus
        outputs: Mapping[str, Any] | None = None
        error: str | None = None
        elapsed_time: float
        total_tokens: int
        total_steps: int
        created_at: int
        # Required but nullable: callers must pass it explicitly, ``None`` is accepted.
        finished_at: int | None
        # Ids of the paused nodes; empty when omitted.
        paused_nodes: Sequence[str] = Field(default_factory=list)
        # One mapping per pause reason; empty when omitted.
        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)

    workflow_run_id: str
    data: Data
class AgentLogStreamResponse(StreamResponse): class AgentLogStreamResponse(StreamResponse):
""" """
AgentLogStreamResponse entity AgentLogStreamResponse entity

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from copy import deepcopy
from typing import Any from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider: class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str tenant_id: str
provider_manager: ProviderManager provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__( def __init__(
self, self,
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id, user_id=run_context.user_id,
) )
self.provider_manager = provider_manager self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id) provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name) provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration: if not provider_configuration:
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None: if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials return credentials
@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager( provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id, tenant_id=run_context.tenant_id,
user_id=run_context.user_id, user_id=run_context.user_id,
) ),
enable_credentials_cache=True,
) )
self.model_manager = model_manager self.model_manager = model_manager
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id, tenant_id=run_context.tenant_id,
user_id=run_context.user_id, user_id=run_context.user_id,
) )
model_manager = ModelManager(provider_manager=provider_manager) model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return ( return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
    """Upload a DSL file to the Creators Platform and return its claim code.

    The file is sent as a multipart ``file`` field to the
    ``api/v1/templates/anonymous-upload`` endpoint under
    ``creators_platform_api_url``.

    Args:
        dsl_file_bytes: Raw content of the DSL file.
        filename: File name reported in the multipart upload.

    Returns:
        The ``claim_code`` found under ``data`` in the JSON response.

    Raises:
        httpx.HTTPStatusError: If the platform answers with an error status
            (raised by ``response.raise_for_status()``).
        ValueError: If the response carries no usable ``claim_code``.
    """
    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
    response.raise_for_status()

    payload = response.json()
    # ``data`` can be present but null, or the body may not be a JSON object at
    # all. Treat every such shape as "no claim code" so callers always get the
    # documented ValueError instead of an AttributeError from ``None.get``.
    data = payload.get("data") if isinstance(payload, dict) else None
    claim_code = data.get("claim_code") if isinstance(data, dict) else None
    if not claim_code:
        raise ValueError("Creators Platform did not return a valid claim_code")
    return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
    """Build the Creators Platform redirect URL for a previously uploaded DSL.

    The URL is ``CREATORS_PLATFORM_API_URL`` (trailing slashes removed) followed
    by a query string that always contains ``dsl_claim_code``. When
    ``CREATORS_PLATFORM_OAUTH_CLIENT_ID`` is configured, an ``oauth_code``
    signed for ``user_account_id`` is added as well.

    Args:
        user_account_id: Account the OAuth authorization code is signed for.
        claim_code: Claim code to pass along as ``dsl_claim_code``.

    Returns:
        The redirect URL including its URL-encoded query string.
    """
    platform_root = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
    query: dict[str, str] = {"dsl_claim_code": claim_code}

    oauth_client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
    if oauth_client_id:
        # Imported lazily - presumably to avoid an import cycle with the
        # services package; TODO confirm.
        from services.oauth_server import OAuthServerService

        query["oauth_code"] = OAuthServerService.sign_oauth_authorization_code(oauth_client_id, user_account_id)

    return f"{platform_root}?{urlencode(query)}"

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import ( from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT, CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT, GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM, LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else: else:
# Default-model generation keeps the built-in suggested-questions tuning. # Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = { model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS, "max_tokens": 2560,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE, "temperature": 0.0,
} }
stop = [] stop = []

View File

@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser: class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None: def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return self._instruction_prompt return self._instruction_prompt

View File

@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n' '["question1","question2","question3"]\n'
) )
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = ( GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge" "<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step." " in the long text. Please think step by step."

View File

@ -1,5 +1,6 @@
import logging import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config from configs import dify_config
@ -36,11 +37,13 @@ class ModelInstance:
Model instance class. Model instance class.
""" """
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle self.provider_model_bundle = provider_model_bundle
self.model_name = model self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields. # Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {} self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = () self.stop: Sequence[str] = ()
@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager: class ModelManager:
def __init__(self, provider_manager: ProviderManager): """Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod @classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type tenant_id=tenant_id, provider=provider, model_type=model_type
) )
model_instance = ModelInstance(provider_model_bundle, model) cred_cache_key = (tenant_id, provider, model_type.value, model)
return model_instance
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
""" """

View File

@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table: if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict: if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"]) data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else: else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable( dataset_keyword_table = DatasetKeywordTable(

View File

@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf.""" """Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags( keywords = self._tfidf.extract_tags(
sentence=text, sentence=text,
topK=max_keywords_per_chunk, topK=max_keywords_per_chunk or 10,
) )
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords) keywords = cast(list[str], keywords)

View File

@ -551,6 +551,7 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all() child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes: for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id) segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map: if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i) child_chunk_map[i.segment_id].append(i)

View File

@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore: class DatasetDocumentStore:
@ -127,6 +128,7 @@ class DatasetDocumentStore:
if save_child: if save_child:
if doc.children: if doc.children:
for position, child in enumerate(doc.children, start=1): for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk( child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id, tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id, dataset_id=self._dataset.id,
@ -137,7 +139,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"), index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content, content=child.page_content,
word_count=len(child.page_content), word_count=len(child.page_content),
type="automatic", type=SegmentType.AUTOMATIC,
created_by=self._user_id, created_by=self._user_id,
) )
db.session.add(child_segment) db.session.add(child_segment)
@ -163,6 +165,7 @@ class DatasetDocumentStore:
) )
# add new child chunks # add new child chunks
for position, child in enumerate(doc.children, start=1): for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk( child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id, tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id, dataset_id=self._dataset.id,
@ -173,7 +176,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"), index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content, content=child.page_content,
word_count=len(child.page_content), word_count=len(child.page_content),
type="automatic", type=SegmentType.AUTOMATIC,
created_by=self._user_id, created_by=self._user_id,
) )
db.session.add(child_segment) db.session.add(child_segment)

View File

@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=dataset_tools, tools=dataset_tools,
stream=False, stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
) )
usage = result.usage or LLMUsage.empty_usage() usage = result.usage or LLMUsage.empty_usage()

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception): class EncryptionError(Exception):
"""OAuth encryption/decryption specific error""" """Encryption/decryption specific error"""
pass pass
class SystemOAuthEncrypter: class SystemEncrypter:
""" """
A simple OAuth parameters encrypter using AES-CBC encryption. A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY. using AES-CBC mode with a key derived from the application's SECRET_KEY.
""" """
def __init__(self, secret_key: str | None = None): def __init__(self, secret_key: str | None = None):
""" """
Initialize the OAuth encrypter. Initialize the encrypter.
Args: Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256 # Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest() self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: def encrypt_params(self, params: Mapping[str, Any]) -> str:
""" """
Encrypt OAuth parameters. Encrypt parameters.
Args: Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns: Returns:
Base64-encoded encrypted string Base64-encoded encrypted string
Raises: Raises:
OAuthEncryptionError: If encryption fails EncryptionError: If encryption fails
ValueError: If oauth_params is invalid ValueError: If params is invalid
""" """
try: try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv) cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data # Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data) encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data # Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode() return base64.b64encode(combined).decode()
except Exception as e: except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
""" """
Decrypt OAuth parameters. Decrypt parameters.
Args: Args:
encrypted_data: Base64-encoded encrypted string encrypted_data: Base64-encoded encrypted string
Returns: Returns:
Decrypted OAuth parameters dictionary Decrypted parameters dictionary
Raises: Raises:
OAuthEncryptionError: If decryption fails EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid ValueError: If encrypted_data is invalid
""" """
if not isinstance(encrypted_data, str): if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size) unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON # Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict): if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary") raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params return params
except Exception as e: except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances # Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
""" """
Create an OAuth encrypter instance. Create an encrypter instance.
Args: Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns: Returns:
SystemOAuthEncrypter instance SystemEncrypter instance
""" """
return SystemOAuthEncrypter(secret_key=secret_key) return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility) # Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None _encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter: def get_system_encrypter() -> SystemEncrypter:
""" """
Get the global OAuth encrypter instance. Get the global encrypter instance.
Returns: Returns:
SystemOAuthEncrypter instance SystemEncrypter instance
""" """
global _oauth_encrypter global _encrypter
if _oauth_encrypter is None: if _encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter() _encrypter = SystemEncrypter()
return _oauth_encrypter return _encrypter
# Convenience functions for backward compatibility # Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: def encrypt_system_params(params: Mapping[str, Any]) -> str:
""" """
Encrypt OAuth parameters using the global encrypter. Encrypt parameters using the global encrypter.
Args: Args:
oauth_params: OAuth parameters dictionary params: Parameters dictionary
Returns: Returns:
Base64-encoded encrypted string Base64-encoded encrypted string
""" """
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
""" """
Decrypt OAuth parameters using the global encrypter. Decrypt parameters using the global encrypter.
Args: Args:
encrypted_data: Base64-encoded encrypted string encrypted_data: Base64-encoded encrypted string
Returns: Returns:
Decrypted OAuth parameters dictionary Decrypted parameters dictionary
""" """
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id( def load_form_tokens_by_form_id(
form_ids: Sequence[str], form_ids: Sequence[str],
*, *,
session: Session | None = None, session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Load the preferred access token for each human input form.""" """Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids)) unique_form_ids = list(dict.fromkeys(form_ids))
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {} return {}
if session is not None: if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids) return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session: with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids) return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: def _load_form_tokens_by_form_id(
tokens_by_form_id: dict[str, tuple[int, str]] = {} session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt): for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) if not recipient.access_token:
if priority is None or not recipient.access_token:
continue continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token) tokens_by_form_id: dict[str, str] = {}
current = tokens_by_form_id.get(recipient.form_id) for form_id, recipients in recipients_by_form_id.items():
if current is None or candidate[0] < current[0]: token = _get_surface_form_token(recipients, surface=surface)
tokens_by_form_id[recipient.form_id] = candidate if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
    recipients: Sequence[tuple[RecipientType, str]],
    *,
    surface: HumanInputSurface | None,
) -> str | None:
    """Pick the access token to expose for one form on the given surface.

    For ``HumanInputSurface.SERVICE_API`` the first non-empty token belonging
    to a ``STANDALONE_WEB_APP`` recipient is returned. In every other case,
    and when no such token exists, the choice is delegated to
    ``get_preferred_form_token``.

    Args:
        recipients: ``(recipient_type, access_token)`` pairs of a single form.
        surface: Surface the token is loaded for, or ``None`` for no preference.

    Returns:
        The selected token, or ``None`` when nothing suitable is found.
    """
    if surface == HumanInputSurface.SERVICE_API:
        web_app_token = next(
            (
                access_token
                for kind, access_token in recipients
                if kind == RecipientType.STANDALONE_WEB_APP and access_token
            ),
            None,
        )
        if web_app_token is not None:
            return web_app_token

    # NOTE(review): for SERVICE_API without a web-app token this fallback may
    # hand back a token of another recipient type - confirm that is intended.
    return get_preferred_form_token(recipients)

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
# Recipient types each surface is allowed to act on.
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
# Lower values win: get_preferred_form_token keeps the token with the smallest
# priority, and recipient types missing from this table are never selected.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
    recipient_type: RecipientType | None,
    surface: HumanInputSurface,
) -> bool:
    """Tell whether a recipient type may be used on the given surface.

    Args:
        recipient_type: Recipient type to check; ``None`` is never allowed.
        surface: Surface whose allow-list is consulted.

    Returns:
        ``True`` when ``recipient_type`` is listed for ``surface`` in
        ``_ALLOWED_RECIPIENT_TYPES_BY_SURFACE``, otherwise ``False``.
    """
    # Short-circuit on None so the allow-list is only consulted for real types.
    return recipient_type is not None and recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
    recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
    """Return the token of the recipient with the best (lowest) priority.

    Recipients whose type has no entry in ``_RECIPIENT_TOKEN_PRIORITY`` or
    whose token is empty are ignored. When several recipients share the best
    priority, the one that appears first in ``recipients`` wins.

    Args:
        recipients: ``(recipient_type, access_token)`` pairs of a single form.

    Returns:
        The preferred token, or ``None`` when no recipient qualifies.
    """
    ranked_tokens = [
        (_RECIPIENT_TOKEN_PRIORITY[kind], access_token)
        for kind, access_token in recipients
        if kind in _RECIPIENT_TOKEN_PRIORITY and access_token
    ]
    if not ranked_tokens:
        return None
    # min() returns the first minimal item, which keeps first-wins tie-breaking.
    _, best_token = min(ranked_tokens, key=lambda ranked: ranked[0])
    return best_token
def enrich_human_input_pause_reasons(
    reasons: Sequence[Mapping[str, Any]],
    *,
    form_tokens_by_form_id: Mapping[str, str],
    expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
    """Copy pause reasons, adding form token and expiration to human-input ones.

    Every reason is shallow-copied into a new ``dict``; the inputs are not
    modified. A reason whose ``"TYPE"`` equals
    ``PauseReasonType.HUMAN_INPUT_REQUIRED`` and whose ``"form_id"`` is a
    string receives ``"form_token"`` (``None`` when the form has no token) and,
    if an expiration is known, ``"expiration_time"``.

    Args:
        reasons: Pause reasons in mapping form.
        form_tokens_by_form_id: Access token per form id.
        expiration_times_by_form_id: Expiration time per form id.

    Returns:
        The copied reasons, in input order.
    """
    result: list[dict[str, Any]] = []
    for reason in reasons:
        entry = dict(reason)
        needs_human_input = entry.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
        form_id = entry.get("form_id") if needs_human_input else None
        if isinstance(form_id, str):
            # The token key is always written, even when no token is known.
            entry["form_token"] = form_tokens_by_form_id.get(form_id)
            expires_at = expiration_times_by_form_id.get(form_id)
            if expires_at is not None:
                entry["expiration_time"] = expires_at
        result.append(entry)
    return result

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum): class QuotaType(StrEnum):
""" """
Supported quota types for tenant feature usage. Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
""" """
# Trigger execution quota
TRIGGER = auto() TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto() WORKFLOW = auto()
UNLIMITED = auto() UNLIMITED = auto()
@property @property
def billing_key(self) -> str: def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self: match self:
case QuotaType.TRIGGER: case QuotaType.TRIGGER:
return "trigger_event" return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit" return "api_rate_limit"
case _: case _:
raise ValueError(f"Invalid quota type: {self}") raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
return attachment_list return attachment_list
class ChildChunk(Base): class ChildChunk(TypeBase):
__tablename__ = "child_chunks" __tablename__ = "child_chunks"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -1046,29 +1046,42 @@ class ChildChunk(Base):
) )
# initial fields # initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False) segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields # indexing fields
index_node_id = mapped_column(String(255), nullable=True) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
index_node_hash = mapped_column(String(255), nullable=True) created_at: Mapped[datetime] = mapped_column(
type: Mapped[SegmentType] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
) )
created_by = mapped_column(StringUUID, nullable=False) updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp() DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
) )
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) indexing_at: Mapped[datetime | None] = mapped_column(
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) DateTime, nullable=True, insert_default=None, server_default=None, init=False
error = mapped_column(LongText, nullable=True) )
completed_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255),
nullable=False,
server_default=sa.text("'automatic'"),
default=SegmentType.AUTOMATIC,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)
@property @property
def dataset(self): def dataset(self):

View File

@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
) )
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
) )
app_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False) question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None) message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
) )

View File

@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data) span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan) assert isinstance(span, ReadableSpan)
assert span.name == "test-span" assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123 assert span.context.trace_id == 123
assert span.context.span_id == 456 assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789 assert span.parent.span_id == 789
assert span.resource == resource assert span.resource == resource
assert span.attributes == {"attr1": "val1"} assert span.attributes == {"attr1": "val1"}

View File

@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self): def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
SpanData( SpanData.model_validate(
trace_id=123, {
# span_id missing "trace_id": 123,
name="test_span", "name": "test_span",
start_time=1000, "start_time": 1000,
end_time=2000, "end_time": 2000,
}
) )
def test_span_data_arbitrary_types_allowed(self): def test_span_data_arbitrary_types_allowed(self):

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from types import SimpleNamespace from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import ( from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION, GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE, GEN_AI_INPUT_MESSAGE,
@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint self.endpoint = endpoint
self.added_spans: list[object] = [] self.added_spans: list[object] = []
def add_span(self, span) -> None: def add_span(self, span: object) -> None:
self.added_spans.append(span) self.added_spans.append(span)
def api_check(self) -> bool: def api_check(self) -> bool:
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id, trace_id=trace_id,
span_id=span_id, span_id=span_id,
is_remote=False, is_remote=False,
trace_flags=TraceFlags.SAMPLED, trace_flags=TraceFlags(TraceFlags.SAMPLED),
) )
return Link(context) return Link(context)
def _make_trace_metadata(
    trace_id: int = 1,
    workflow_span_id: int = 2,
    session_id: str = "s",
    user_id: str = "u",
    links: list[Link] | None = None,
) -> TraceMetadata:
    """Build a TraceMetadata for tests; an omitted ``links`` becomes a fresh empty list."""
    # Only None is replaced: an explicitly passed list (even an empty one) is forwarded as-is.
    resolved_links = links if links is not None else []
    return TraceMetadata(
        trace_id=trace_id,
        workflow_span_id=workflow_span_id,
        session_id=session_id,
        user_id=user_id,
        links=resolved_links,
    )
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
    """Return the instance's trace client, typed as RecordingTraceClient for the type checker."""
    # cast() performs no runtime conversion or check; it only narrows the static type.
    client = trace_instance.trace_client
    return cast(RecordingTraceClient, client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
    """Return the client's ``added_spans`` list, typed as SpanData items for attribute access."""
    # The same list object is returned (no copy); the cast only affects static typing.
    recorded = _recording_trace_client(trace_instance).added_spans
    return cast(list[SpanData], recorded)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = { defaults = {
"workflow_id": "workflow-id", "workflow_id": "workflow-id",
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info) trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once() add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1] passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111 assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222 assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c" assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u" assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == [] assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None) trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info) trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == [] assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
) )
trace_instance.message_trace(trace_info) trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2 spans = _recorded_span_data(trace_instance)
message_span, llm_span = trace_instance.trace_client.added_spans assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message" assert message_span.name == "message"
assert message_span.trace_id == 10 assert message_span.trace_id == 10
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None) trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info) trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == [] assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1 spans = _recorded_span_data(trace_instance)
span = trace_instance.trace_client.added_spans[0] assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval" assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query" assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None) trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info) trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == [] assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
) )
) )
assert len(trace_instance.trace_client.added_spans) == 1 spans = _recorded_span_data(trace_instance)
span = trace_instance.trace_client.added_spans[0] assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool" assert span.name == "my-tool"
assert span.status == status assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool" assert span.attributes[TOOL_NAME] == "my-tool"
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info() trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock() trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
): ):
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info() trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock() trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info() trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock() trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info() trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock() trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
): ):
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info() trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock() trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE node_execution.node_type = BuiltinNodeTypes.CODE
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK) status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id" node_execution.id = "node-id"
node_execution.title = "title" node_execution.title = "title"
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK) status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id" node_execution.id = "node-id"
node_execution.title = "my-tool" node_execution.title = "my-tool"
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
) )
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id" node_execution.id = "node-id"
node_execution.title = "retrieval" node_execution.title = "retrieval"
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id" node_execution.id = "node-id"
node_execution.title = "llm" node_execution.title = "llm"
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK) status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) trace_metadata = _make_trace_metadata()
# CASE 1: With message_id # CASE 1: With message_id
trace_info = _make_workflow_trace_info( trace_info = _make_workflow_trace_info(
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
) )
trace_instance.add_workflow_span(trace_info, trace_metadata) trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2 client = _recording_trace_client(trace_instance)
message_span = trace_instance.trace_client.added_spans[0] spans = _recorded_span_data(trace_instance)
workflow_span = trace_instance.trace_client.added_spans[1] assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message" assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER assert message_span.span_kind == SpanKind.SERVER
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20 assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear() client.added_spans.clear()
# CASE 2: Without message_id # CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None) trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1 spans = _recorded_span_data(trace_instance)
span = trace_instance.trace_client.added_spans[0] assert len(spans) == 1
span = spans[0]
assert span.name == "workflow" assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None assert span.parent_span_id is None
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info) trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1 spans = _recorded_span_data(trace_instance)
span = trace_instance.trace_client.added_spans[0] assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question" assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@ -1,4 +1,6 @@
import json import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import ( from dify_trace_aliyun.entities.semconv import (
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents(): def test_format_retrieval_documents():
# Not a list # Not a list
assert format_retrieval_documents("not a list") == [] assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list # Valid list
docs = [ docs = [
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages(): def test_format_input_messages():
# Not a dict # Not a dict
assert format_input_messages(None) == serialize_json_data([]) assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts # No prompts
assert format_input_messages({}) == serialize_json_data([]) assert format_input_messages({}) == serialize_json_data([])
@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages(): def test_format_output_messages():
# Not a dict # Not a dict
assert format_output_messages(None) == serialize_json_data([]) assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text # No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self): def test_missing_required_fields(self):
"""Test that required fields are enforced""" """Test that required fields are enforced"""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
AliyunConfig() AliyunConfig.model_validate({})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license") AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self): def test_app_name_validation_empty(self):
"""Test app_name validation with empty value""" """Test app_name validation with empty value"""

View File

@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr" return "SilentErrorRepr"
span.reset_mock() span.reset_mock()
set_span_status(span, SilentError()) set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self): def test_missing_required_fields(self):
"""Test that required fields are enforced""" """Test that required fields are enforced"""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangfuseConfig() LangfuseConfig.model_validate({})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangfuseConfig(public_key="public") LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret") LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self): def test_host_validation_empty(self):
"""Test host validation with empty value""" """Test host validation with empty value"""

View File

@ -2,6 +2,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from types import SimpleNamespace from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig from dify_trace_langfuse.config import LangfuseConfig
@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self): def test_missing_required_fields(self):
"""Test that required fields are enforced""" """Test that required fields are enforced"""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangSmithConfig() LangSmithConfig.model_validate({})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangSmithConfig(api_key="key") LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
LangSmithConfig(project="project") LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self): def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS""" """Test endpoint validation only allows HTTPS"""

View File

@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock() span = MagicMock()
mock_tracing["start"].return_value = span mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token" mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info()) trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once() mock_tracing["start"].assert_called_once()
@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock() span = MagicMock()
mock_tracing["start"].return_value = span mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token" mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke") trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info) trace_instance.message_trace(trace_info)
@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock() span = MagicMock()
mock_tracing["start"].return_value = span mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token" mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test") monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png") file_data = SimpleNamespace(url="path/to/file.png")
@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock() span = MagicMock()
mock_tracing["start"].return_value = span mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token" mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None) trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info) trace_instance.message_trace(trace_info)
@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock() end_user = MagicMock()
end_user.session_id = "session-xyz" end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info( trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock() span = MagicMock()
mock_tracing["start"].return_value = span mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token" mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info( trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"}, metadata={"from_account_id": "acc-1"},

View File

@ -12,6 +12,7 @@ from __future__ import annotations
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_span)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _seed_to_uuid4 # _seed_to_uuid4
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self): def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
assert instance.add_span.called assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self): def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info) expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0] root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self): def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0] root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self): def test_trace_name_is_workflow_trace(self):
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0] trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self): def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0] root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self): def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None) trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info) instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0] root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"] assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self): def test_node_execution_spans_are_parented_to_root(self):
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec]) instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span # call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2 add_span = _add_span_mock(instance)
node_span_kwargs = instance.add_span.call_args_list[1][0][0] assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self): def test_node_span_not_parented_to_workflow_app_log_id(self):
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec]) instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0] node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self): def test_root_span_id_differs_from_trace_id(self):
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info) instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0] trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self): def test_root_span_uses_workflow_run_id_directly(self):
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info) instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0] root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self): def test_root_span_id_differs_from_no_message_id_case(self):
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec]) instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0] node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import sys import sys
import types import types
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = [] metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = [] meter_provider_instances: list[DummyMeterProvider] = []
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve.""" """Drop fake metric modules into sys.modules so the client imports resolve."""
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter") span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
} }
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
def _build_client() -> TencentTraceClient: def _build_client() -> TencentTraceClient:
return TencentTraceClient( return TencentTraceClient(
service_name="service", service_name="service",
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None: def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5) client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration") client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom") client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called() logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
span = patch_core_components["span"] span = patch_core_components["span"]
span.get_span_context.return_value = "ctx" ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData( data = SpanData(
trace_id=1, trace_id=1,
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once() span.add_event.assert_called_once()
span.set_status.assert_called_once() span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20) span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx" assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
client.span_contexts[10] = "existing" existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"] span = patch_core_components["span"]
span.get_span_context.return_value = "child" span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData( data = SpanData(
trace_id=1, trace_id=1,
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data) client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"] trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing") trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once() trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
span = patch_core_components["span"] span = patch_core_components["span"]
span.get_span_context.return_value = "ctx" span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom") client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span( client._create_and_export_span(
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm" assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
span_processor = patch_core_components["span_processor"] span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"] tracer_provider = patch_core_components["tracer_provider"]
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once() metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
meter_provider = meter_provider_instances[-1] meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom") meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom") client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown() client.shutdown()
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once() logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client() client = _build_client()
span = patch_core_components["span"] span = patch_core_components["span"]
span.get_span_context.return_value = "ctx" span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct( data = SpanData.model_construct(
trace_id=1, trace_id=1,
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration") hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args _, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str) assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2 assert attrs["bar"] == 2
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration") hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True}) client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args _, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str) assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True assert attrs["ok"] is True
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
], ],
) )
def test_record_methods_handle_exceptions( def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None: ) -> None:
client = _build_client() client = _build_client()
hist_mock = MagicMock(name=attr_name) hist_mock = MagicMock(name=attr_name)
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None: def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client() client = _build_client()
metric_reader = metric_reader_instances[-1] metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False assert exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client() client = _build_client()
metric_reader = metric_reader_instances[-1] metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint assert exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client() client = _build_client()
metric_reader = metric_reader_instances[-1] metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint assert exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client() _ = _build_client()
metric_reader = metric_reader_instances[-1] metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self): def test_missing_required_fields(self):
"""Test that required fields are enforced""" """Test that required fields are enforced"""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WeaveConfig() WeaveConfig.model_validate({})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WeaveConfig(api_key="key") WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
WeaveConfig(project="project") WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self): def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS""" """Test endpoint validation only allows HTTPS"""

View File

@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password) auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth) options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options) self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name) self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name) self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name self._bucket_name = config.bucket_name
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
try: try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search( search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
) )

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any, TypedDict from typing import Any, TypedDict, cast
from packaging import version from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None): def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None: if fields is None:
# Load collection fields from remote server # Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name) collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]] fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it # Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY] self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False return False
try: try:
milvus_version = self._client.get_server_version() milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version: if "Zilliz Cloud" in milvus_version:
return True return True

View File

@ -3,7 +3,7 @@ import json
import logging import logging
import re import re
import uuid import uuid
from typing import Any from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore import jieba.posseg as pseg # type: ignore
import numpy import numpy
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel): class OracleVectorConfig(BaseModel):
user: str user: str
password: str password: str
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection return connection
def _create_connection_pool(self, config: OracleVectorConfig): def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = { pool_params = _OraclePoolParams(
"user": config.user, user=config.user,
"password": config.password, password=config.password,
"dsn": config.dsn, dsn=config.dsn,
"min": 1, min=1,
"max": 5, max=5,
"increment": 1, increment=1,
} )
if config.is_autonomous: if config.is_autonomous:
pool_params.update( pool_params["config_dir"] = config.config_dir
{ pool_params["wallet_location"] = config.wallet_location
"config_dir": config.config_dir, pool_params["wallet_password"] = config.wallet_password
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
return oracledb.create_pool(**pool_params) return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):

View File

@ -9,6 +9,7 @@ dependencies = [
"boto3>=1.42.91", "boto3>=1.42.91",
"celery>=5.6.3", "celery>=5.6.3",
"croniter>=6.2.2", "croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
"flask-cors>=6.0.2", "flask-cors>=6.0.2",
"gevent>=26.4.0", "gevent>=26.4.0",
"gevent-websocket>=0.10.1", "gevent-websocket>=0.10.1",
@ -117,7 +118,7 @@ dev = [
"faker>=40.15.0", "faker>=40.15.0",
"lxml-stubs>=0.5.1", "lxml-stubs>=0.5.1",
"basedpyright>=1.39.3", "basedpyright>=1.39.3",
"ruff>=0.15.11", "ruff>=0.15.12",
"pytest>=9.0.3", "pytest>=9.0.3",
"pytest-benchmark>=5.2.3", "pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0", "pytest-cov>=7.1.0",
@ -144,7 +145,7 @@ dev = [
"types-pexpect>=4.9.0", "types-pexpect>=4.9.0",
"types-protobuf>=7.34.1", "types-protobuf>=7.34.1",
"types-psutil>=7.2.2", "types-psutil>=7.2.2",
"types-psycopg2>=2.9.21", "types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0", "types-pygments>=2.20.0",
"types-pymysql>=1.1.0", "types-pymysql>=1.1.0",
"types-python-dateutil>=2.9.0", "types-python-dateutil>=2.9.0",
@ -157,9 +158,9 @@ dev = [
"types-tensorflow>=2.18.0.20260408", "types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408", "types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0", "types-ujson>=5.10.0",
"boto3-stubs>=1.42.92", "boto3-stubs>=1.42.96",
"types-jmespath>=1.1.0.20260408", "types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.1", "hypothesis>=6.152.3",
"types_pyOpenSSL>=24.1.0", "types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408", "types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408", "types_setuptools>=82.0.0.20260408",
@ -169,7 +170,7 @@ dev = [
"import-linter>=2.3", "import-linter>=2.3",
"types-redis>=4.6.0.20241004", "types-redis>=4.6.0.20241004",
"celery-types>=0.23.0", "celery-types>=0.23.0",
"mypy>=1.20.1", "mypy>=1.20.2",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0", "pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0", "pytest-xdist>=3.8.0",

View File

@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm from models.human_input import HumanInputForm, HumanInputFormRecipient
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
def _build_human_input_required_reason( def _build_human_input_required_reason(
reason_model: WorkflowPauseReason, reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None, form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired: ) -> HumanInputRequired:
form_content = "" form_content = ""
inputs = [] inputs = []
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values) resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title node_title = definition.node_title or node_title
return HumanInputRequired( reason = HumanInputRequired(
form_id=form_id, form_id=form_id,
form_content=form_content, form_content=form_content,
inputs=inputs, inputs=inputs,
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
node_title=node_title, node_title=node_title,
resolved_default_values=resolved_default_values, resolved_default_values=resolved_default_values,
) )
return reason
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all(): for form in session.scalars(form_stmt).all():
form_models[form.id] = form form_models[form.id] = form
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = [] pause_reasons: list[PauseReason] = []
for reason in pause_reason_models: for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id) form_model = form_models.get(reason.form_id)
pause_reasons.append(_build_human_input_required_reason(reason, form_model)) pause_reasons.append(
_build_human_input_required_reason(
reason,
form_model,
recipients_by_form_id.get(reason.form_id, ()),
)
)
else: else:
pause_reasons.append(reason.to_entity()) pause_reasons.append(reason.to_entity())
return pause_reasons return pause_reasons

View File

@ -133,7 +133,14 @@ class AppAnnotationService:
raise ValueError("'question' is required when 'message_id' is not provided") raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) annotation = MessageAnnotation(
app_id=app.id,
conversation_id=None,
message_id=None,
content=answer,
question=question,
account_id=current_user.id,
)
db.session.add(annotation) db.session.add(annotation)
db.session.commit() db.session.commit()

View File

@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory from core.db import session_factory
from enums.quota_type import QuotaType, unlimited from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited() quota_charge = unlimited()
if dify_config.BILLING_ENABLED: if dify_config.BILLING_ENABLED:
try: try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError: except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key() request_id = RateLimit.gen_request_key()
try: try:
request_id = rate_limit.enter(request_id) request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = ( effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
) )
@ -162,6 +164,7 @@ class AppGenerateService:
invoke_from=invoke_from, invoke_from=invoke_from,
streaming=True, streaming=True,
call_depth=0, call_depth=0,
workflow_run_id=str(uuid.uuid4()),
) )
payload_json = payload.model_dump_json() payload_json = payload.model_dump_json()
@ -183,6 +186,10 @@ class AppGenerateService:
else: else:
# Blocking mode: run synchronously and return JSON instead of SSE # Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch. # Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator() advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate( return rate_limit.generate(
advanced_generator.convert_to_event_stream( advanced_generator.convert_to_event_stream(
@ -194,6 +201,7 @@ class AppGenerateService:
invoke_from=invoke_from, invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
streaming=False, streaming=False,
pause_state_config=pause_config,
) )
), ),
request_id=request_id, request_id=request_id,

View File

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -88,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow # 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit the read-only session before starting the billing RPC call
session.commit()
# 3. Get dispatcher based on tenant subscription # 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -131,9 +135,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log) trigger_log = trigger_log_repo.create(trigger_log)
session.commit() session.commit()
# 7. Check and consume quota # 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try: try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id) quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e: except QuotaExceededError as e:
# Update trigger log status # Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +158,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue # 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json") task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None try:
if queue_name == QueuePriority.PROFESSIONAL: task: AsyncResult[Any] | None = None
task = execute_workflow_professional.delay(task_data_dict) if queue_name == QueuePriority.PROFESSIONAL:
elif queue_name == QueuePriority.TEAM: task = execute_workflow_professional.delay(task_data_dict)
task = execute_workflow_team.delay(task_data_dict) elif queue_name == QueuePriority.TEAM:
else: # SANDBOX task = execute_workflow_team.delay(task_data_dict)
task = execute_workflow_sandbox.delay(task_data_dict) else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info # 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED trigger_log.status = WorkflowTriggerStatus.QUEUED
@ -295,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs] return [log.to_dict() for log in logs]
@staticmethod @staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
""" """
Get workflow for the app Get workflow for the app
Args: Args:
app_model: App model instance app_model: App model instance
workflow_id: Optional specific workflow ID workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns: Returns:
Workflow instance Workflow instance
@ -311,12 +329,12 @@ class AsyncWorkflowService:
""" """
if workflow_id: if workflow_id:
# Get specific published workflow # Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow: if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else: else:
# Get default published workflow # Get default published workflow
workflow = workflow_service.get_published_workflow(app_model) workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow: if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict):
expiration_date: int expiration_date: int
class QuotaReserveResult(TypedDict):
    """Shape of the ``POST /quota/reserve`` response returned by ``BillingService.quota_reserve``."""

    # Identifier of the reservation; sent back as ``reservation_id`` to
    # /quota/commit or /quota/release.
    reservation_id: str
    # NOTE(review): presumably the quota still available / currently frozen
    # after the reservation — confirm against the billing API.
    available: int
    reserved: int
class QuotaCommitResult(TypedDict):
    """Shape of the ``POST /quota/commit`` response returned by ``BillingService.quota_commit``."""

    # NOTE(review): presumably the tenant's available / still-reserved quota
    # after the commit — confirm against the billing API.
    available: int
    reserved: int
    # NOTE(review): presumably the part of the reservation handed back when
    # ``actual_amount`` is lower than the reserved amount — confirm.
    refunded: int
class QuotaReleaseResult(TypedDict):
    """Shape of the ``POST /quota/release`` response returned by ``BillingService.quota_release``."""

    # NOTE(review): presumably the tenant's available / still-reserved quota
    # after the release — confirm against the billing API.
    available: int
    reserved: int
    # NOTE(review): presumably the amount unfrozen by this release — confirm.
    released: int
# Module-level validators, built once, for the raw responses of
# /quota/reserve, /quota/commit and /quota/release respectively.
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _TenantFeatureQuota(TypedDict):
    """Usage and limit of a single feature inside the ``/quota/info`` response."""

    usage: int
    # ``QuotaService.get_remaining`` treats a limit of -1 as "unlimited".
    limit: int
    # Optional key. NOTE(review): presumably a timestamp of the next quota
    # reset — confirm unit/format against the billing API.
    reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
    """Response of /quota/info.

    NOTE (hj24):
        - Same convention as BillingInfo: billing may return int fields as str,
          always keep non-strict mode to auto-coerce.
    """

    # One usage/limit entry per feature key present in the response.
    trigger_event: _TenantFeatureQuota
    api_rate_limit: _TenantFeatureQuota
# Validator for the raw /quota/info response; used by BillingService.get_quota_info.
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict): class _BillingQuota(TypedDict):
size: int size: int
limit: int limit: int
@ -149,11 +193,63 @@ class BillingService:
@classmethod @classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str): def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id} params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
    """Fetch ``/quota/info`` for a tenant and validate it as ``TenantFeatureQuotaInfo``."""
    raw_info = cls._send_request("GET", "/quota/info", params={"tenant_id": tenant_id})
    return _tenant_feature_quota_info_adapter.validate_python(raw_info)
@classmethod
def quota_reserve(
    cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
    """Reserve quota before task execution.

    Sends ``POST /quota/reserve``. ``meta`` is added to the request body only
    when it is non-empty. The response is validated as ``QuotaReserveResult``.
    """
    # Omit the "meta" key entirely when no (or empty) metadata was supplied.
    optional_fields: dict = {"meta": meta} if meta else {}
    response = cls._send_request(
        "POST",
        "/quota/reserve",
        json={
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "request_id": request_id,
            "amount": amount,
            **optional_fields,
        },
    )
    return _quota_reserve_adapter.validate_python(response)
@classmethod
def quota_commit(
    cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
    """Commit a reservation with actual consumption.

    Sends ``POST /quota/commit``. ``meta`` is added to the request body only
    when it is non-empty. The response is validated as ``QuotaCommitResult``.
    """
    # Omit the "meta" key entirely when no (or empty) metadata was supplied.
    optional_fields: dict = {"meta": meta} if meta else {}
    response = cls._send_request(
        "POST",
        "/quota/commit",
        json={
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "reservation_id": reservation_id,
            "actual_amount": actual_amount,
            **optional_fields,
        },
    )
    return _quota_commit_adapter.validate_python(response)
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
    """Release a reservation (cancel, return frozen quota).

    Sends ``POST /quota/release`` and validates the response as
    ``QuotaReleaseResult``.
    """
    release_request = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "reservation_id": reservation_id,
    }
    response = cls._send_request("POST", "/quota/release", json=release_request)
    return _quota_release_adapter.validate_python(response)
@classmethod @classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id} params = {"tenant_id": tenant_id}

View File

@ -3748,6 +3748,7 @@ class SegmentService:
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
) )
) )
assert current_user.current_tenant_id
child_chunk = ChildChunk( child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
@ -3758,7 +3759,7 @@ class SegmentService:
index_node_hash=index_node_hash, index_node_hash=index_node_hash,
content=content, content=content,
word_count=len(content), word_count=len(content),
type="customized", type=SegmentType.CUSTOMIZED,
created_by=current_user.id, created_by=current_user.id,
) )
db.session.add(child_chunk) db.session.add(child_chunk)
@ -3818,6 +3819,7 @@ class SegmentService:
if new_child_chunks_args: if new_child_chunks_args:
child_chunk_count = len(child_chunks) child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
assert current_user.current_tenant_id
index_node_id = str(uuid.uuid4()) index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content) index_node_hash = helper.generate_text_hash(args.content)
child_chunk = ChildChunk( child_chunk = ChildChunk(
@ -3830,7 +3832,7 @@ class SegmentService:
index_node_hash=index_node_hash, index_node_hash=index_node_hash,
content=args.content, content=args.content,
word_count=len(args.content), word_count=len(args.content),
type="customized", type=SegmentType.CUSTOMIZED,
created_by=current_user.id, created_by=current_user.id,
) )

View File

@ -5,6 +5,7 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config from configs import dify_config
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService: class EnterpriseService:
@classmethod @classmethod
@ttl_cache(ttl=5)
def get_info(cls): def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info") return EnterpriseRequest.send_request("GET", "/info")

View File

@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel() plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = [] trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False enable_trial_app: bool = False
enable_explore_banner: bool = False enable_explore_banner: bool = False
@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED: if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features return system_features
@classmethod @classmethod
@ -286,7 +290,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id) billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"] features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"] features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
    """
    Handle for a quota reservation (Reserve phase).

    Two-phase lifecycle:
        charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except Exception:
            charge.refund()  # Release frozen quota
            raise

    ``QuotaService.consume`` is the one-shot variant: it reserves and commits
    in a single call, so the charge it returns is normally already committed
    and must not be used as the starting point of the two-phase flow above.

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.

    A charge without a ``charge_id`` (billing disabled, or the ``unlimited()``
    placeholder) holds no reservation: commit() and refund() do nothing.
    """

    success: bool
    charge_id: str | None  # reservation_id
    _quota_type: QuotaType
    _tenant_id: str | None = None
    _feature_key: str | None = None
    _amount: int = 0
    # Set once BillingService.quota_commit has succeeded; later commit() calls
    # then return immediately.
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with actual amount.

        Does nothing when the charge is already committed or when it carries
        no reservation (missing charge_id, tenant id or feature key).
        Any failure while committing is logged and swallowed; the charge then
        stays uncommitted, so a later commit() attempts the call again.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically
                (by the billing service).
        """
        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        try:
            # Imported lazily rather than at module import time.
            from services.billing_service import BillingService

            amount = actual_amount if actual_amount is not None else self._amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
        - charge failed or was disabled (charge_id is None)
        - already committed (Release after Commit is a no-op)
        - already refunded (idempotent)

        This method guarantees no exceptions will be raised: the actual call
        is delegated to ``QuotaService.release``, which logs and swallows
        every error.
        """
        if not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
    """Build a successful placeholder charge that carries no reservation."""
    # QuotaType is only imported for type checking at module level, so it is
    # resolved here at call time.
    from enums.quota_type import QuotaType as _QuotaType

    placeholder = QuotaCharge(charge_id=None, success=True, _quota_type=_QuotaType.UNLIMITED)
    return placeholder
class QuotaService:
    """Drives the quota reserve / commit / release lifecycle through BillingService."""

    @staticmethod
    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        One-shot mode: reserve, then commit straight away.

        The charge that comes back still supports .refund(), which issues a
        Release. Callers that need two phases (e.g. streaming) should call
        reserve() themselves.
        """
        reservation = QuotaService.reserve(quota_type, tenant_id, amount)
        # Nothing to commit when no reservation was actually created.
        if not (reservation.success and reservation.charge_id):
            return reservation
        reservation.commit()
        return reservation

    @staticmethod
    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve quota ahead of task execution (Reserve phase only).

        The caller MUST follow up with charge.commit() once the task succeeds,
        or charge.refund() when it fails.

        With billing disabled a charge without reservation is returned. Any
        unexpected failure while talking to billing is logged and answered
        with ``unlimited()``, i.e. the request is allowed.

        Raises:
            QuotaExceededError: When quota is insufficient
            ValueError: When ``amount`` is not positive (billing enabled only)
        """
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)

        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)

        if amount <= 0:
            raise ValueError("Amount to reserve must be greater than 0")

        request_id = str(uuid.uuid4())
        feature_key = quota_type.billing_key

        try:
            response = BillingService.quota_reserve(
                tenant_id=tenant_id,
                feature_key=feature_key,
                request_id=request_id,
                amount=amount,
            )

            reservation_id = response.get("reservation_id")
            if not reservation_id:
                logger.warning(
                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
                    tenant_id,
                    quota_type.value,
                    response,
                )
                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)

            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=feature_key,
                _amount=amount,
            )
        except (QuotaExceededError, ValueError):
            # NOTE(review): ValueError is propagated instead of failing open.
            # If response validation errors derive from ValueError (pydantic's
            # ValidationError does), a malformed reserve response surfaces to
            # the caller rather than returning unlimited() — confirm intended.
            raise
        except Exception:
            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
            return unlimited()

    @staticmethod
    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
        """
        Tell whether at least ``amount`` quota remains, without reserving it.

        Returns True when billing is disabled, when the remaining quota is
        reported as -1, or when the lookup itself fails.

        Raises:
            ValueError: When ``amount`` is not positive (billing enabled only)
        """
        if not dify_config.BILLING_ENABLED:
            return True

        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")

        try:
            left = QuotaService.get_remaining(quota_type, tenant_id)
            if left == -1:
                return True
            return left >= amount
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
            return True

    @staticmethod
    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
        """Release a reservation. Never raises: every error is logged and swallowed."""
        try:
            from services.billing_service import BillingService

            # Nothing to release without billing or without a reservation.
            if not (dify_config.BILLING_ENABLED and reservation_id):
                return

            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
            BillingService.quota_release(
                tenant_id=tenant_id,
                feature_key=feature_key,
                reservation_id=reservation_id,
            )
        except Exception:
            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)

    @staticmethod
    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
        """
        Compute the remaining quota of a feature from ``BillingService.get_quota_info``.

        Returns:
            -1 when the feature's limit is -1 or when the lookup fails,
            0 when the payload or the feature entry is not a dict,
            otherwise ``limit - usage`` clamped at 0 (missing keys count as 0).
        """
        from services.billing_service import BillingService

        try:
            quota_info = BillingService.get_quota_info(tenant_id)
            if not isinstance(quota_info, dict):
                return 0

            entry = quota_info.get(quota_type.billing_key, {})
            if not isinstance(entry, dict):
                return 0

            limit = entry.get("limit", 0)
            used = entry.get("usage", 0)
            if limit == -1:
                return -1
            return max(0, limit - used)
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
            return -1

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID from models.provider_ids import ToolProviderID
@ -521,7 +521,7 @@ class BuiltinToolManageService:
) )
if system_client: if system_client:
try: try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e: except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}") raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import ( from core.trigger.entities.api_entities import (
TriggerProviderApiEntity, TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity, TriggerProviderSubscriptionApiEntity,
@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client: if system_client:
try: try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e: except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}") raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData from services.workflow.entities import WebhookTriggerData
@ -798,45 +799,47 @@ class WebhookService:
Exception: If workflow execution fails Exception: If workflow execution fails
""" """
try: try:
with Session(db.engine) as session: workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data trigger_data = WebhookTriggerData(
trigger_data = WebhookTriggerData( app_id=webhook_trigger.app_id,
app_id=webhook_trigger.app_id, workflow_id=workflow.id,
workflow_id=workflow.id, root_node_id=webhook_trigger.node_id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node inputs=workflow_inputs,
inputs=workflow_inputs, tenant_id=webhook_trigger.tenant_id,
tenant_id=webhook_trigger.tenant_id, )
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
) )
raise
end_user = EndUserService.get_or_create_end_user_by_type( try:
type=InvokeFrom.TRIGGER, # NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
tenant_id=webhook_trigger.tenant_id, # trigger_workflow_async needs to handle multiple session commits internally
app_id=webhook_trigger.app_id, with Session(db.engine, expire_on_commit=False) as session:
user_id=None, AsyncWorkflowService.trigger_workflow_async(
) session,
end_user,
# consume quota before triggering workflow execution trigger_data,
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
) )
raise quota_charge.commit()
except Exception:
# Trigger workflow execution asynchronously quota_charge.refund()
AsyncWorkflowService.trigger_workflow_async( raise
session,
end_user,
trigger_data,
)
except Exception: except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -16,6 +16,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from models import UploadFile from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.enums import SegmentType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -178,7 +179,7 @@ class VectorService:
index_node_hash=child_chunk.metadata["doc_hash"], index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content, content=child_chunk.page_content,
word_count=len(child_chunk.page_content), word_count=len(child_chunk.page_content),
type="automatic", type=SegmentType.AUTOMATIC,
created_by=dataset_document.created_by, created_by=dataset_document.created_by,
) )
db.session.add(child_segment) db.session.add(child_segment)
@ -222,6 +223,7 @@ class VectorService:
) )
documents.append(new_child_document) documents.append(new_child_document)
for update_child_chunk in update_child_chunks: for update_child_chunk in update_child_chunks:
assert update_child_chunk.index_node_id
child_document = Document( child_document = Document(
page_content=update_child_chunk.content, page_content=update_child_chunk.content,
metadata={ metadata={
@ -234,6 +236,7 @@ class VectorService:
documents.append(child_document) documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id) delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks: for delete_child_chunk in delete_child_chunks:
assert delete_child_chunk.index_node_id
delete_node_ids.append(delete_child_chunk.index_node_id) delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index # update vector index
@ -246,6 +249,7 @@ class VectorService:
@classmethod @classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset) vector = Vector(dataset=dataset)
assert child_chunk.index_node_id
vector.delete_by_ids([child_chunk.index_node_id]) vector.delete_by_ids([child_chunk.index_node_id])
@classmethod @classmethod

Some files were not shown because too many files have changed in this diff Show More