Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement

This commit is contained in:
QuantumGhost 2026-05-07 16:44:29 +08:00
commit 3f6559dd60
1318 changed files with 94094 additions and 27033 deletions

3
.github/CODEOWNERS vendored
View File

@ -6,6 +6,9 @@
* @crazywoola @laipz8200 @Yeuoly
# ESLint suppression file is maintained by autofix.ci pruning.
/eslint-suppressions.json
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola

View File

@ -4,7 +4,7 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
with:
node-version-file: .nvmrc
cache: true

1
.github/labeler.yml vendored
View File

@ -6,5 +6,4 @@ web:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
api-unit:
name: API Unit Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-unit
defaults:
@ -62,7 +62,7 @@ jobs:
api-integration:
name: API Integration Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-integration
STORAGE_TYPE: opendal
@ -137,7 +137,7 @@ jobs:
api-coverage:
name: API Coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
needs:
- api-unit
- api-integration

View File

@ -13,7 +13,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'
@ -43,7 +43,6 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
@ -114,7 +113,7 @@ jobs:
find . -name "*.py.bak" -type f -delete
- name: Setup web environment
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
if: github.event_name != 'merge_group'
uses: ./.github/actions/setup-web
- name: ESLint autofix

View File

@ -26,6 +26,9 @@ jobs:
build:
runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify'
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
@ -35,28 +38,28 @@ jobs:
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
steps:
- name: Prepare
@ -70,8 +73,8 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Extract metadata for Docker
id: meta
@ -81,16 +84,15 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest
env:
@ -108,9 +110,33 @@ jobs:
if-no-files-found: error
retention-days: 1
fork-build-validate:
if: github.repository != 'langgenius/dify'
runs-on: ubuntu-24.04
strategy:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Validate Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: linux/amd64
create-manifest:
needs: build
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
strategy:
matrix:

View File

@ -9,7 +9,7 @@ concurrency:
jobs:
db-migration-test-postgres:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -59,7 +59,7 @@ jobs:
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -110,6 +110,28 @@ jobs:
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
# to return (container processes started); it does not wait on healthcheck
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
# wait the migration runs while InnoDB is still initialising and gets
# killed with "Lost connection during query". Poll a real SELECT until it
# succeeds.
- name: Wait for MySQL to accept queries
run: |
set +e
for i in $(seq 1 60); do
if docker run --rm --network host mysql:8.0 \
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
-e 'SELECT 1' >/dev/null 2>&1; then
echo "MySQL ready after ${i}s"
exit 0
fi
sleep 1
done
echo "MySQL not ready after 60s; dumping container logs:"
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
exit 1
- name: Run DB Migration
env:
DEBUG: true

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev'

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'

View File

@ -14,28 +14,59 @@ concurrency:
jobs:
build-docker:
if: github.event.pull_request.head.repo.full_name == github.repository
runs-on: ${{ matrix.runs_on }}
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Build Docker Image
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-docker-fork:
if: github.event.pull_request.head.repo.full_name != github.repository
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
@ -48,6 +79,4 @@ jobs:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64

View File

@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:

View File

@ -23,7 +23,7 @@ concurrency:
jobs:
pre_job:
name: Skip Duplicate Checks
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps:
@ -39,7 +39,7 @@ jobs:
name: Check Changed Files
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
@ -69,7 +69,6 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@ -83,7 +82,6 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
@ -141,7 +139,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests."
@ -154,7 +152,7 @@ jobs:
- check-changes
- api-tests-run
- api-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize API Tests status
env:
@ -201,7 +199,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests."
@ -214,7 +212,7 @@ jobs:
- check-changes
- web-tests-run
- web-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Tests status
env:
@ -260,7 +258,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@ -273,7 +271,7 @@ jobs:
- check-changes
- web-e2e-run
- web-e2e-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Full-Stack E2E status
env:
@ -325,7 +323,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests."
@ -338,7 +336,7 @@ jobs:
- check-changes
- vdb-tests-run
- vdb-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize VDB Tests status
env:
@ -384,7 +382,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests."
@ -397,7 +395,7 @@ jobs:
- check-changes
- db-migration-test-run
- db-migration-test-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize DB Migration Test status
env:

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-diff:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -16,7 +16,7 @@ jobs:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -12,7 +12,7 @@ on:
jobs:
stale:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
issues: write
pull-requests: write

View File

@ -15,7 +15,7 @@ permissions:
jobs:
python-style:
name: Python Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -57,7 +57,7 @@ jobs:
web-style:
name: Web Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:
working-directory: ./web
@ -83,7 +83,6 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@ -131,7 +130,7 @@ jobs:
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code

View File

@ -9,7 +9,6 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}
@ -18,7 +17,7 @@ concurrency:
jobs:
build:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:

View File

@ -35,7 +35,7 @@ concurrency:
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 120
steps:
@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 5
steps:

View File

@ -16,7 +16,7 @@ jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: VDB Smoke Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: Web Full-Stack E2E
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04-4
defaults:
run:
shell: bash

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04-4
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
@ -54,7 +54,7 @@ jobs:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04-4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
@ -92,7 +92,7 @@ jobs:
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04-4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:

3
.gitignore vendored
View File

@ -219,6 +219,9 @@ node_modules
# plugin migrate
plugins.jsonl
# generated API OpenAPI specs
packages/contracts/openapi/
# mise
mise.toml

1
.npmrc
View File

@ -1 +0,0 @@
save-exact=true

View File

@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
```bash
cd dify
cd docker
cp .env.example .env
docker compose up -d
./dify-compose up -d
```
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
#### Seeking help
@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana
@ -147,7 +148,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -113,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
# Validates name encoding for non-Latin characters.
name = name.strip().encode("utf-8").decode("utf-8") if name else None
# generate random password
new_password = secrets.token_urlsafe(16)
# Generate a random password that satisfies the password policy.
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
for _ in range(100):
new_password = secrets.token_urlsafe(16)
try:
valid_password(new_password)
break
except Exception:
continue
else:
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
return
# register account
account = RegisterService.register(

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response):
# Try to extract filename from URL
parsed_url = urllib.parse.urlparse(url)
url_path = parsed_url.path
filename = os.path.basename(url_path)
# Decode percent-encoded characters in the path segment
filename = urllib.parse.unquote(os.path.basename(url_path))
# If filename couldn't be extracted, use Content-Disposition header
if not filename:

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
action: str

View File

@ -1,4 +1,5 @@
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Literal
@ -8,6 +9,7 @@ from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
@ -66,22 +69,19 @@ class AppListQuery(BaseModel):
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not isinstance(value, list):
raise ValueError("Unsupported tag_ids type.")
items = [str(item).strip() for item in value if item and str(item).strip()]
if not items:
return None
@ -91,6 +91,26 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
normalized: dict[str, str | list[str]] = {}
indexed_tag_ids: list[tuple[int, str]] = []
for key in query_args:
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
if match:
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
continue
value = query_args.get(key)
if value is not None:
normalized[key] = value
if indexed_tag_ids:
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
return normalized
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -455,7 +475,7 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
args_dict = args.model_dump()
# get app list
@ -692,6 +712,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@ -60,7 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -158,8 +159,13 @@ class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
class WorkflowOnlineUsersQuery(BaseModel):
app_ids: str = Field(..., description="Comma-separated app IDs")
class WorkflowOnlineUsersPayload(BaseModel):
app_ids: list[str] = Field(default_factory=list, description="App IDs")
@field_validator("app_ids")
@classmethod
def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
class DraftWorkflowTriggerRunPayload(BaseModel):
@ -186,7 +192,7 @@ reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersQuery)
reg(WorkflowOnlineUsersPayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
@ -1384,19 +1390,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.route("/apps/workflows/online-users")
class WorkflowOnlineUsersApi(Resource):
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
@console_ns.doc("get_workflow_online_users")
@console_ns.doc(description="Get workflow online users")
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
def post(self):
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
app_ids = args.app_ids
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
if not app_ids:
return {"data": []}
@ -1404,13 +1410,24 @@ class WorkflowOnlineUsersApi(Resource):
_, current_tenant_id = current_account_with_tenant()
workflow_service = WorkflowService()
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
users_json_by_app_id: dict[str, Any] = {}
for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
app_id_batch = ordered_accessible_app_ids[
start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
]
pipe = redis_client.pipeline(transaction=False)
for app_id in app_id_batch:
pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
users_json_batch = pipe.execute()
for app_id, users_json in zip(app_id_batch, users_json_batch):
users_json_by_app_id[app_id] = users_json
results = []
for app_id in app_ids:
if app_id not in accessible_app_ids:
continue
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
for app_id in ordered_accessible_app_ids:
users_json = users_json_by_app_id.get(app_id, {})
users = []
for _, user_info_json in users_json.items():

View File

@ -38,6 +38,48 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _normalize_hit_testing_query(query: Any) -> str:
"""Return the user-visible query string from legacy and current response shapes."""
if isinstance(query, str):
return query
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
return content
raise ValueError("Invalid hit testing query response")
@staticmethod
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Coerce nullable collection fields into lists before response validation."""
if not isinstance(records, list):
return []
normalized_records: list[dict[str, Any]] = []
for record in records:
if not isinstance(record, dict):
continue
normalized_record = dict(record)
segment = normalized_record.get("segment")
if isinstance(segment, dict):
normalized_segment = dict(segment)
if normalized_segment.get("keywords") is None:
normalized_segment["keywords"] = []
normalized_record["segment"] = normalized_segment
if normalized_record.get("child_chunks") is None:
normalized_record["child_chunks"] = []
if normalized_record.get("files") is None:
normalized_record["files"] = []
normalized_records.append(normalized_record)
return normalized_records
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
@ -75,7 +117,12 @@ class DatasetsHitTestingBase:
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
return {
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -32,7 +32,7 @@ class TagBindingPayload(BaseModel):
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
@ -152,41 +152,68 @@ class TagUpdateDeleteApi(Resource):
return "", 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class TagBindingRemoveApi(Resource):
"""Batch resource for tag binding deletion."""
@console_ns.doc("remove_tag_bindings")
@console_ns.doc(description="Remove one or more tag bindings from a target.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _remove_tag_bindings()

View File

@ -8,6 +8,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import supported_language
@ -45,6 +46,8 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -322,9 +325,24 @@ class AccountAvatarApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
avatar = args.avatar
avatar_url = file_helpers.get_signed_file_url(args.avatar)
if avatar.startswith(("http://", "https://")):
return {"avatar_url": avatar}
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
raise NotFound("Avatar file not found")
if upload_file.tenant_id != current_tenant_id:
raise NotFound("Avatar file not found")
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {"avatar_url": avatar_url}
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])

View File

@ -23,9 +23,11 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@ -50,6 +52,7 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@ -58,6 +61,7 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
# Keep app-token callers scoped to the public web-form surface; internal HITL
# routes must continue to flow through console-only authentication.
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token
def get(self, app_model: App, form_token: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form submitted successfully",
400: "Bad request - invalid submission data",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, form_token: str):
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
recipient_type = form.recipient_type
if recipient_type is None:
logger.warning("Recipient type is None for form, form_id=%s", form.id)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=end_user.id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(
params={
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": (
"Whether to replay from persisted state snapshot, "
'specify `"true"` to include a status snapshot of executed nodes'
),
"continue_on_pause": (
"Whether to keep the stream open across workflow_paused events,"
'specify `"true"` to keep the stream open for `workflow_paused` events.'
),
}
)
@service_api_ns.doc(
responses={
200: "SSE event stream",
401: "Unauthorized - invalid API token",
404: "Workflow run not found",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise NotWorkflowAppError()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFound("Workflow run not found")
if workflow_run.created_by != end_user.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
class TagUnbindingPayload(BaseModel):
tag_id: str
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
tag_ids: list[str] = Field(default_factory=list)
tag_id: str | None = None
target_id: str
@model_validator(mode="before")
@classmethod
def normalize_legacy_tag_id(cls, data: object) -> object:
if not isinstance(data, dict):
return data
if not data.get("tag_ids") and data.get("tag_id"):
return {**data, "tag_ids": [data["tag_id"]]}
return data
@model_validator(mode="after")
def validate_tag_ids(self) -> "TagUnbindingPayload":
if not self.tag_ids:
raise ValueError("Tag IDs is required.")
return self
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc("unbind_dataset_tags")
@service_api_ns.doc(description="Unbind tags from a dataset")
@service_api_ns.doc(
responses={
204: "Tag unbound successfully",
204: "Tags unbound successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
}
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204

View File

@ -468,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
"""Update a document from an uploaded file for canonical and deprecated routes."""
dataset_id_str = str(dataset_id)
tenant_id_str = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args: dict[str, object] = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"""Deprecated resource aliases for file document updates."""
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc("update_document_by_file_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for updating an existing document by uploading a file. "
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
)
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
@ -487,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file through the deprecated file-update aliases."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@ -876,6 +886,22 @@ class DocumentApi(DatasetApiResource):
return response
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
401: "Unauthorized - invalid API token",
404: "Document not found",
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file on the canonical document resource."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.doc("delete_document")
@service_api_ns.doc(description="Delete a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
@ -28,11 +28,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
class HumanInputUploadTokenResponse(BaseModel):
upload_token: str
expires_at: int

View File

@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AdvancedChatPausedBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = {
"event": "message",
"event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
"""
Process generate task pipeline.
:return:
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return AdvancedChatPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=paused_nodes,
reasons=reasons,
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import JsonValue
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
@classmethod
def convert(
@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data: dict[str, JsonValue] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = []

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
response: dict[str, Any] = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
"""
Process generate task pipeline.
:return:
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
created_at = int(runtime_state.start_at)
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=human_input_responses[-1].workflow_run_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
created_at=created_at,
finished_at=None,
paused_nodes=paused_nodes,
reasons=reasons,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -1,12 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
"""
Public pause-reason payload used by blocking responses when only
``human_input_required`` events are available.
"""
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@classmethod
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
return cls(
form_id=data.form_id,
node_id=data.node_id,
node_title=data.node_title,
form_content=data.form_content,
inputs=data.inputs,
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@ -357,7 +392,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -414,7 +449,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -776,6 +811,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
"""
ChatbotAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
message_id: str
workflow_run_id: str
answer: str
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
status: WorkflowExecutionStatus
elapsed_time: float
total_tokens: int
total_steps: int
data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@ -821,6 +884,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
"""
WorkflowAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
workflow_id: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_at: int
finished_at: int | None
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []

View File

@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt

View File

@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@ -151,6 +151,12 @@ def deserialize_response(raw_data: bytes) -> Response:
response = Response(response=body, status=status_code)
# Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
# parsed ones so we faithfully reproduce the original response. Use Headers.add
# rather than dict-style assignment so that repeated headers such as Set-Cookie
# (and any other multi-valued header per RFC 9110) are preserved instead of
# being overwritten.
response.headers.clear()
for line in lines[1:]:
if not line:
continue
@ -158,6 +164,6 @@ def deserialize_response(raw_data: bytes) -> Response:
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
response.headers[name] = value.strip()
response.headers.add(name, value.strip())
return response

View File

@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -445,7 +445,7 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
provider_name_to_provider_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
@ -462,7 +462,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
@ -478,7 +478,7 @@ class ProviderManager:
:return:
"""
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
@ -496,7 +496,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_settings_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
@ -514,7 +514,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
@ -544,7 +544,7 @@ class ProviderManager:
return {}
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
@ -578,7 +578,7 @@ class ProviderManager:
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderCredential)
.where(
@ -608,7 +608,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderModelCredential)
.where(

View File

@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
topK=max_keywords_per_chunk or 10,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)

View File

@ -217,10 +217,11 @@ class RetrievalService:
"""Deduplicate documents in O(n) while preserving first-seen order.
Rules:
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
- If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key;
keep the doc with the highest metadata["score"] among duplicates. If a later duplicate
has no score, ignore it.
- If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content),
keeping the first occurrence.
"""
if not documents:
return documents
@ -231,11 +232,10 @@ class RetrievalService:
order: list[tuple] = []
for doc in documents:
is_dify = doc.provider == "dify"
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
doc_id = (doc.metadata or {}).get("doc_id")
if is_dify and doc_id:
key = ("dify", doc_id)
if doc_id:
key = (doc.provider or "dify", doc_id)
if key not in chosen:
chosen[key] = doc
order.append(key)
@ -551,6 +551,7 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

View File

@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC):
return index_struct_dict
class _LazyEmbeddings(Embeddings):
"""Lazy proxy that defers materializing the real embedding model.
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
transitively calls ``FeatureService.get_features`` ``BillingService``
HTTP GETs (see ``provider_manager.py``). Cleanup paths
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
at all, so deferring this until an ``embed_*`` method is actually invoked
keeps cleanup tasks resilient to transient billing-API failures and avoids
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
hiccups.
Existing callers that perform create / search operations are unaffected:
the first ``embed_*`` call materializes the underlying model and the
behavior is identical from that point on.
"""
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._real: Embeddings | None = None
def _ensure(self) -> Embeddings:
if self._real is None:
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
self._real = CacheEmbedding(embedding_model)
return self._real
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)
async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
@ -60,7 +112,11 @@ class Vector:
"original_chunk_id",
]
self._dataset = dataset
self._embeddings = self._get_embeddings()
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
# never transitively trigger billing API calls during ``Vector(dataset)``
# construction. The real embedding model is materialized only when an
# ``embed_*`` method is actually invoked (i.e. create / search paths).
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
self._attributes = attributes
self._vector_processor = self._init_vector()
@ -88,8 +144,20 @@ class Vector:
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
return get_vector_factory_class(vector_type)
@staticmethod
def _filter_empty_text_documents(documents: list[Document]) -> list[Document]:
filtered_documents = [document for document in documents if document.page_content.strip()]
skipped_count = len(documents) - len(filtered_documents)
if skipped_count:
logger.warning("skip %d empty documents before vector embedding", skipped_count)
return filtered_documents
def create(self, texts: list | None = None, **kwargs):
if texts:
texts = self._filter_empty_text_documents(texts)
if not texts:
return
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
@ -147,8 +215,14 @@ class Vector:
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
documents = self._filter_empty_text_documents(documents)
if not documents:
return
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
if not documents:
return
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)

View File

@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore:
@ -127,6 +128,7 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -137,7 +139,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)
@ -163,6 +165,7 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -173,7 +176,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)

View File

@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@ -1078,6 +1078,13 @@ class ToolManager:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
if variable_pool:
config = tool_configurations.get(parameter.name, {})
selector_value = cls._extract_runtime_selector_value(parameter, config)
if selector_value is not None:
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
runtime_parameters[parameter.name] = selector_value
continue
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
@ -1105,5 +1112,39 @@ class ToolManager:
runtime_parameters[parameter.name] = value
return runtime_parameters
@classmethod
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
if parameter.type not in {
ToolParameter.ToolParameterType.MODEL_SELECTOR,
ToolParameter.ToolParameterType.APP_SELECTOR,
}:
return None
if not isinstance(config, dict):
return None
input_value = config.get("value")
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
if cls._is_selector_value(parameter, config):
selector_value = dict(config)
selector_value.pop("type", None)
selector_value.pop("value", None)
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
return None
@classmethod
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
)
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
return isinstance(value.get("app_id"), str)
return False
ToolManager.load_hardcoded_providers_cache()

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -272,6 +272,14 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
normalized_tool_configurations[name] = value
continue
selector_value = _extract_selector_configuration(value)
if selector_value is not None:
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
@ -310,6 +318,28 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
return None
def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
input_value = value.get("value")
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
return dict(input_value)
if _is_selector_configuration(value):
selector_value = dict(value)
selector_value.pop("type", None)
selector_value.pop("value", None)
return selector_value
return None
def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
) or isinstance(value.get("app_id"), str)
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)

View File

@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids)
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
tokens_by_form_id: dict[str, tuple[int, str]] = {}
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
if priority is None or not recipient.access_token:
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token)
current = tokens_by_form_id.get(recipient.form_id)
if current is None or candidate[0] < current[0]:
tokens_by_form_id[recipient.form_id] = candidate
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token
return get_preferred_form_token(recipients)

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
recipient_type: RecipientType | None,
surface: HumanInputSurface,
) -> bool:
if recipient_type is None:
return False
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
chosen_token: str | None = None
chosen_priority: int | None = None
for recipient_type, token in recipients:
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
if priority is None or not token:
continue
if chosen_priority is None or priority < chosen_priority:
chosen_priority = priority
chosen_token = token
return chosen_token
def enrich_human_input_pause_reasons(
reasons: Sequence[Mapping[str, Any]],
*,
form_tokens_by_form_id: Mapping[str, str],
expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
enriched: list[dict[str, Any]] = []
for reason in reasons:
updated = dict(reason)
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is not None:
updated["expiration_time"] = expiration_time
enriched.append(updated)
return enriched

View File

@ -365,7 +365,8 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
adapted_node_config = adapt_node_config_for_graph(node_config)
typed_node_config = NodeConfigDictAdapter.validate_python(adapted_node_config)
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
@ -373,6 +374,11 @@ class DifyNodeFactory(NodeFactory):
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
config_for_node_init: BaseNodeData | dict[str, Any]
if isinstance(resolved_node_data, BaseNodeData):
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
else:
config_for_node_init = resolved_node_data
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -442,7 +448,7 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
node_id=node_id,
config=resolved_node_data,
config=config_for_node_init,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@ -474,10 +480,7 @@ class DifyNodeFactory(NodeFactory):
include_retriever_attachment_loader: bool,
include_jinja2_template_renderer: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
)
validated_node_data = cast(LLMCompatibleNodeData, node_data)
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,

View File

@ -506,11 +506,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@staticmethod
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
tool_configurations = dict(node_data.tool_configurations)
tool_configurations.update(
{name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()}
)
return _WorkflowToolRuntimeSpec(
provider_type=CoreToolProviderType(node_data.provider_type.value),
provider_id=node_data.provider_id,
tool_name=node_data.tool_name,
tool_configurations=dict(node_data.tool_configurations),
tool_configurations=tool_configurations,
credential_id=node_data.credential_id,
)

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -3,6 +3,7 @@ import logging
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from events.app_event import app_draft_workflow_was_synced
from graphon.nodes import BuiltinNodeTypes
from graphon.nodes.tool.entities import ToolEntity
@ -19,7 +20,8 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
adapted_node_data = adapt_node_config_for_graph(node_data)
tool_entity = ToolEntity.model_validate(adapted_node_data["data"])
provider_type = ToolProviderType(tool_entity.provider_type.value)
tool_runtime = ToolManager.get_tool_runtime(
provider_type=provider_type,

View File

@ -1,7 +1,9 @@
from flask import Flask
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
def init_app(app: Flask):
with app.app_context():
configure_session_factory(db.engine)

View File

@ -298,7 +298,7 @@ def _build_from_datasource_file(
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),

View File

@ -19,8 +19,13 @@ from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
def extract_filename(url_path: str, content_disposition: str | None) -> str | None:
"""Extract a safe filename from Content-Disposition or the request URL path."""
def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
"""Extract a safe filename from Content-Disposition or the request URL path.
Handles full URLs, paths with query strings, hash fragments, and percent-encoded segments.
Query strings and hash fragments are stripped from the URL before extracting the basename.
Percent-encoded characters in the path are decoded safely.
"""
filename: str | None = None
if content_disposition:
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
@ -47,8 +52,13 @@ def extract_filename(url_path: str, content_disposition: str | None) -> str | No
filename = urllib.parse.unquote(raw)
if not filename:
candidate = os.path.basename(url_path)
filename = urllib.parse.unquote(candidate) if candidate else None
# Parse the URL to extract just the path, stripping query strings and fragments
# This handles both full URLs and bare paths
parsed = urllib.parse.urlparse(url_or_path)
path = parsed.path
candidate = os.path.basename(path)
# Decode percent-encoded characters, with safe fallback for malformed input
filename = urllib.parse.unquote(candidate, errors="replace") if candidate else None
if filename:
filename = os.path.basename(filename)

View File

@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
return attachment_list
class ChildChunk(Base):
class ChildChunk(TypeBase):
__tablename__ = "child_chunks"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -1046,29 +1046,42 @@ class ChildChunk(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255),
nullable=False,
server_default=sa.text("'automatic'"),
default=SegmentType.AUTOMATIC,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)
@property
def dataset(self):

View File

@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@ -2179,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
return result
class UploadFile(Base):
class UploadFile(TypeBase):
__tablename__ = "upload_files"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@ -2187,9 +2190,12 @@ class UploadFile(Base):
)
# NOTE: The `id` field is generated within the application to minimize extra roundtrips
# (especially when generating `source_url`).
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
# (especially when generating `source_url`) and keep model metadata portable across databases.
id: Mapped[str] = mapped_column(
StringUUID,
init=False,
default_factory=lambda: str(uuid4()),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
@ -2197,16 +2203,6 @@ class UploadFile(Base):
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
extension: Mapped[str] = mapped_column(String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# The `created_by` field stores the ID of the entity that created this upload file.
#
# If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
@ -2225,10 +2221,18 @@ class UploadFile(Base):
# `used` may indicate whether the file has been utilized by another service.
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# `used_by` may indicate the ID of the user who utilized this file.
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
source_url: Mapped[str] = mapped_column(LongText, default="")
def __init__(

View File

@ -9,11 +9,11 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from core.db.session_factory import session_factory
from graphon.model_runtime.entities.model_entities import ModelType
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@ -82,7 +82,8 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
with session_factory.create_session() as session:
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
with session_factory.create_session() as session:
return session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
def credential_name(self):

View File

@ -50,7 +50,7 @@ from libs.uuid_utils import uuidv7
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from .model import AppMode, UploadFile
from .model import AppMode
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
@ -63,6 +63,10 @@ from .account import Account
from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships
# must target the class object directly instead of relying on string lookup across registries.
from .model import UploadFile
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
@ -1096,8 +1100,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
from .model import UploadFile
stmt = sa.select(UploadFile).where(UploadFile.id == file_id)
file = session.scalars(stmt).first()
assert file is not None, f"UploadFile with id {file_id} should exist but not"
@ -1191,10 +1193,11 @@ class WorkflowNodeExecutionOffload(Base):
)
file: Mapped[Optional["UploadFile"]] = orm.relationship(
UploadFile,
foreign_keys=[file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id,
)
@ -1565,12 +1568,14 @@ class WorkflowDraftVariable(Base):
),
)
# Relationship to WorkflowDraftVariableFile
# WorkflowDraftVariableFile uses TypeBase while WorkflowDraftVariable uses Base, so the relationship
# must resolve the class object lazily instead of relying on string lookup across registries.
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
lambda: WorkflowDraftVariableFile,
foreign_keys=[file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id",
primaryjoin=lambda: orm.foreign(WorkflowDraftVariable.file_id) == WorkflowDraftVariableFile.id,
)
# Cache for deserialized value
@ -1889,7 +1894,7 @@ class WorkflowDraftVariable(Base):
return self.last_edited_at is not None
class WorkflowDraftVariableFile(Base):
class WorkflowDraftVariableFile(TypeBase):
"""Stores metadata about files associated with large workflow draft variables.
This model acts as an intermediary between WorkflowDraftVariable and UploadFile,
@ -1903,18 +1908,7 @@ class WorkflowDraftVariableFile(Base):
__tablename__ = "workflow_draft_variable_files"
# Primary key
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
default=lambda: str(uuidv7()),
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default_factory=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(
StringUUID,
@ -1966,12 +1960,21 @@ class WorkflowDraftVariableFile(Base):
nullable=False,
)
# Relationship to UploadFile
# Rows are created with `upload_file_id`; callers should load this relationship explicitly when needed.
upload_file: Mapped["UploadFile"] = orm.relationship(
UploadFile,
foreign_keys=[upload_file_id],
lazy="raise",
init=False,
uselist=False,
primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default_factory=naive_utc_now,
server_default=func.current_timestamp(),
)

View File

@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
trace_id: int = 1,
workflow_span_id: int = 2,
session_id: str = "s",
user_id: str = "u",
links: list[Link] | None = None,
) -> TraceMetadata:
return TraceMetadata(
trace_id=trace_id,
workflow_span_id=workflow_span_id,
session_id=session_id,
user_id=user_id,
links=[] if links is None else links,
)
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
return cast(RecordingTraceClient, trace_instance.trace_client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

Some files were not shown because too many files have changed in this diff Show More