mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 05:51:07 +08:00
Merge remote-tracking branch 'origin/main' into deploy/dev
# Conflicts: # api/core/app/entities/task_entities.py # api/core/llm_generator/llm_generator.py # api/core/llm_generator/output_parser/suggested_questions_after_answer.py # api/core/llm_generator/prompts.py # api/models/comment.py # api/services/workflow_event_snapshot_service.py # api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py # api/tests/unit_tests/core/llm_generator/test_llm_generator.py # api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py # eslint-suppressions.json # web/app/components/app/app-publisher/index.tsx # web/app/components/datasets/settings/permission-selector/index.tsx # web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx # web/app/components/workflow/block-selector/blocks.tsx # web/app/components/workflow/block-selector/main.tsx # web/i18n/ar-TN/app-debug.json # web/i18n/de-DE/app-debug.json # web/i18n/en-US/workflow.json # web/i18n/es-ES/app-debug.json # web/i18n/fa-IR/app-debug.json # web/i18n/fr-FR/app-debug.json # web/i18n/hi-IN/app-debug.json # web/i18n/id-ID/app-debug.json # web/i18n/it-IT/app-debug.json # web/i18n/ja-JP/app-debug.json # web/i18n/ko-KR/app-debug.json # web/i18n/nl-NL/app-debug.json # web/i18n/pl-PL/app-debug.json # web/i18n/pt-BR/app-debug.json # web/i18n/ro-RO/app-debug.json # web/i18n/ru-RU/app-debug.json # web/i18n/sl-SI/app-debug.json # web/i18n/th-TH/app-debug.json # web/i18n/tr-TR/app-debug.json # web/i18n/uk-UA/app-debug.json # web/i18n/vi-VN/app-debug.json # web/i18n/zh-Hans/workflow.json # web/i18n/zh-Hant/app-debug.json
This commit is contained in:
commit
d92946e241
@ -367,7 +367,7 @@ For each extraction:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Extract code │
|
||||
│ 2. Run: pnpm lint:fix │
|
||||
│ 3. Run: pnpm type-check:tsgo │
|
||||
│ 3. Run: pnpm type-check │
|
||||
│ 4. Run: pnpm test │
|
||||
│ 5. Test functionality manually │
|
||||
│ 6. PASS? → Next extraction │
|
||||
|
||||
@ -127,7 +127,7 @@ For the current file being tested:
|
||||
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test:coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
- [ ] Run `pnpm type-check`
|
||||
|
||||
## Common Issues to Watch
|
||||
|
||||
|
||||
19
.github/workflows/anti-slop.yml
vendored
19
.github/workflows/anti-slop.yml
vendored
@ -1,19 +0,0 @@
|
||||
name: Anti-Slop PR Check
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, edited, synchronize]
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
anti-slop:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
failure-add-pr-labels: "needs-revision"
|
||||
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
@ -35,7 +35,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -84,7 +84,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -105,7 +105,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -156,7 +156,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
10
.github/workflows/autofix.yml
vendored
10
.github/workflows/autofix.yml
vendored
@ -25,7 +25,7 @@ jobs:
|
||||
- name: Check Docker Compose inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
@ -35,7 +35,7 @@ jobs:
|
||||
- name: Check web inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: web-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
@ -48,7 +48,7 @@ jobs:
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@ -58,7 +58,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@ -123,4 +123,4 @@ jobs:
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
|
||||
|
||||
8
.github/workflows/db-migration-test.yml
vendored
8
.github/workflows/db-migration-test.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -40,7 +40,7 @@ jobs:
|
||||
cp middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/pyrefly-type-coverage.yml
vendored
2
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
12
.github/workflows/style.yml
vendored
12
.github/workflows/style.yml
vendored
@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@ -73,7 +73,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
@ -93,7 +93,7 @@ jobs:
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
@ -122,7 +122,7 @@ jobs:
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
@ -140,7 +140,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
|
||||
2
.github/workflows/tool-test-sdks.yaml
vendored
2
.github/workflows/tool-test-sdks.yaml
vendored
@ -30,7 +30,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version: 22
|
||||
cache: ''
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
|
||||
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
4
.github/workflows/vdb-tests-full.yml
vendored
4
.github/workflows/vdb-tests-full.yml
vendored
@ -36,7 +36,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -65,7 +65,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
4
.github/workflows/vdb-tests.yml
vendored
4
.github/workflows/vdb-tests.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
2
.github/workflows/web-e2e.yml
vendored
2
.github/workflows/web-e2e.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -236,6 +236,10 @@ scripts/stress-test/reports/
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# vitest browser mode attachments (failure screenshots, traces, etc.)
|
||||
.vitest-attachments/
|
||||
**/__screenshots__/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
@ -30,7 +30,7 @@ The codebase is split into:
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
|
||||
@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
|
||||
@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
```
|
||||
uv run dev/generate_swagger_specs.py --output-dir openapi
|
||||
```
|
||||
|
||||
use https://jsontotable.org/openapi-to-typescript to convert to typescript
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Creators Platform integration
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable Creators Platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client ID for Creators Platform integration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -1405,6 +1426,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
CreatorsPlatformConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
|
||||
6
api/controllers/common/human_input.py
Normal file
6
api/controllers/common/human_input.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
@ -728,6 +728,32 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
|
||||
@ -8,10 +8,10 @@ from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_recipient_type(form: Form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
|
||||
raise NotFoundError("form not found")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
# So we need to assert it manually.
|
||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
||||
|
||||
@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
# Default to the initial phase; any legacy/unexpected client input is
|
||||
# coerced back to `old_email` so we never trust the caller to declare
|
||||
# later phases without a verified predecessor token.
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# The token used to request a new-email code must come from the
|
||||
# old-email verification step. This prevents the bypass described
|
||||
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
phase=send_phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Only advance tokens that were minted by the matching send-code step;
|
||||
# refuse tokens that have already progressed or lack a phase marker so
|
||||
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
|
||||
# is strictly enforced.
|
||||
phase_transitions = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
# Refresh token data by generating a new token that carries the
|
||||
# upgraded phase so later steps can check it.
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
# Only tokens that completed both verification phases may be used to
|
||||
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
|
||||
# the initial send-code step could be replayed directly here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Bind the new email to the token that was mailed and verified, so a
|
||||
# verified token cannot be reused with a different `new_email` value.
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
# Revoke only after all checks pass so failed attempts don't burn a
|
||||
# legitimately verified token.
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@ -1,3 +1,11 @@
|
||||
"""Console workspace endpoint controllers.
|
||||
|
||||
This module exposes workspace-scoped plugin endpoint management APIs. The
|
||||
canonical write routes follow resource-oriented paths, while the historical
|
||||
verb-based aliases stay available as deprecated resources so OpenAPI metadata
|
||||
marks only the legacy paths as deprecated.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
class EndpointUpdatePayload(BaseModel):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LegacyEndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
@ -76,6 +89,7 @@ register_schema_models(
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
LegacyEndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
@ -88,8 +102,60 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints")
|
||||
class EndpointCollectionApi(Resource):
|
||||
"""Canonical collection resource for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@ -104,22 +170,33 @@ class EndpointCreateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
return _create_endpoint()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class DeprecatedEndpointCreateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.route("/workspaces/current/endpoints/<string:id>")
|
||||
class EndpointItemApi(Resource):
|
||||
"""Canonical item resource for endpoint updates and deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class DeprecatedEndpointDeleteApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for deleting a plugin endpoint. "
|
||||
"Use DELETE /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
class DeprecatedEndpointUpdateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint updates."""
|
||||
|
||||
@console_ns.doc("update_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating a plugin endpoint. "
|
||||
"Use PATCH /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
@ -233,19 +357,8 @@ class EndpointUpdateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
|
||||
@ -23,9 +23,11 @@ from .app import (
|
||||
conversation,
|
||||
file,
|
||||
file_preview,
|
||||
human_input_form,
|
||||
message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
from .dataset import (
|
||||
dataset,
|
||||
@ -50,6 +52,7 @@ __all__ = [
|
||||
"file",
|
||||
"file_preview",
|
||||
"hit_testing",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"message",
|
||||
"metadata",
|
||||
@ -58,6 +61,7 @@ __all__ = [
|
||||
"segment",
|
||||
"site",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
||||
api.add_namespace(service_api_ns)
|
||||
|
||||
137
api/controllers/service_api/app/human_input_form.py
Normal file
137
api/controllers/service_api/app/human_input_form.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""
|
||||
Service API human input form endpoints.
|
||||
|
||||
This module exposes app-token authenticated APIs for fetching and submitting
|
||||
paused human input forms in workflow/chatflow runs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
|
||||
# Keep app-token callers scoped to the public web-form surface; internal HITL
|
||||
# routes must continue to flow through console-only authentication.
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@service_api_ns.route("/form/human_input/<string:form_token>")
|
||||
class WorkflowHumanInputFormApi(Resource):
|
||||
@service_api_ns.doc("get_human_input_form")
|
||||
@service_api_ns.doc(description="Get a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Form retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Form not found",
|
||||
412: "Form already submitted or expired",
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, form_token: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@service_api_ns.doc("submit_human_input_form")
|
||||
@service_api_ns.doc(description="Submit a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Form submitted successfully",
|
||||
400: "Bad request - invalid submission data",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Form not found",
|
||||
412: "Form already submitted or expired",
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, form_token: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%s", form.id)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_end_user_id=end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
142
api/controllers/service_api/app/workflow_events.py
Normal file
142
api/controllers/service_api/app/workflow_events.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Service API workflow resume event stream endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@service_api_ns.route("/workflow/<string:task_id>/events")
|
||||
class WorkflowEventsApi(Resource):
|
||||
"""Service API for getting workflow execution events after resume."""
|
||||
|
||||
@service_api_ns.doc("get_workflow_events")
|
||||
@service_api_ns.doc(description="Get workflow execution events stream after resume")
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"task_id": "Workflow run ID",
|
||||
"user": "End user identifier (query param)",
|
||||
"include_state_snapshot": (
|
||||
"Whether to replay from persisted state snapshot, "
|
||||
'specify `"true"` to include a status snapshot of executed nodes'
|
||||
),
|
||||
"continue_on_pause": (
|
||||
"Whether to keep the stream open across workflow_paused events,"
|
||||
'specify `"true"` to keep the stream open for `workflow_paused` events.'
|
||||
),
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "SSE event stream",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Workflow run not found",
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
@ -1,4 +1,12 @@
|
||||
"""Service API endpoints for dataset document management.
|
||||
|
||||
The canonical Service API paths use hyphenated route segments. Legacy underscore
|
||||
aliases remain registered for backward compatibility, but they must stay marked
|
||||
deprecated in generated API docs so clients migrate toward the canonical paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
@ -117,12 +125,137 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Create a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id_str,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
"""Resource for the canonical text document creation route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@ -138,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document creation."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a new document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/document/create-by-text instead."
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document created successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text through the deprecated underscore alias."""
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
"""Resource for the canonical text document update route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@ -229,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document updates."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text through the deprecated underscore alias."""
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
|
||||
@ -39,7 +39,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -656,7 +660,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse
|
||||
| AdvancedChatPausedBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
|
||||
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
|
||||
paused_data = blocking_response.data.model_dump(mode="json")
|
||||
return {
|
||||
"event": StreamEvent.WORKFLOW_PAUSED.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"conversation_id": blocking_response.data.conversation_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
"workflow_run_id": blocking_response.data.workflow_run_id,
|
||||
"data": paused_data,
|
||||
}
|
||||
|
||||
response = {
|
||||
"event": "message",
|
||||
"event": StreamEvent.MESSAGE.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
if isinstance(metadata, dict):
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
Generator[ChatbotAppStreamResponse, None, None],
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
human_input_responses: list[HumanInputRequiredResponse] = []
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
status=stream_response.data.status,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
),
|
||||
)
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {}
|
||||
if stream_response.metadata:
|
||||
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
else:
|
||||
continue
|
||||
|
||||
if human_input_responses:
|
||||
return self._build_paused_blocking_response_from_human_input(human_input_responses)
|
||||
|
||||
raise ValueError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _build_paused_blocking_response_from_human_input(
|
||||
self, human_input_responses: list[HumanInputRequiredResponse]
|
||||
) -> AdvancedChatPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=human_input_responses[-1].workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=paused_nodes,
|
||||
reasons=reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
|
||||
total_tokens=runtime_state.total_tokens,
|
||||
total_steps=runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
@ -12,8 +14,10 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppGenerateResponseConverter(ABC):
|
||||
_blocking_response_type: type[AppBlockingResponse]
|
||||
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
@classmethod
|
||||
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
|
||||
return cast(TBlockingResponse, response)
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
@ -21,7 +25,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -30,7 +34,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -40,12 +44,12 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -107,13 +111,13 @@ class AppGenerateResponseConverter(ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
@ -127,7 +131,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data: dict[str, Any] | None = None
|
||||
data: dict[str, JsonValue] | None = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -65,6 +65,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=(
|
||||
HumanInputSurface.SERVICE_API
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
pause_reasons = enrich_human_input_pause_reasons(
|
||||
pause_reasons,
|
||||
form_tokens_by_form_id=form_token_by_form_id,
|
||||
expiration_times_by_form_id={
|
||||
form_id: int(expiration_time.timestamp())
|
||||
for form_id, expiration_time in expiration_times_by_form_id.items()
|
||||
},
|
||||
)
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = CompletionAppBlockingResponse
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
response: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
||||
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from models.model import AppMode
|
||||
@ -26,6 +27,7 @@ class MessageGenerator:
|
||||
idle_timeout=300,
|
||||
ping_interval: float = 10.0,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
@ -33,4 +35,5 @@ class MessageGenerator:
|
||||
idle_timeout=idle_timeout,
|
||||
ping_interval=ping_interval,
|
||||
on_subscribe=on_subscribe,
|
||||
terminal_events=terminal_events,
|
||||
)
|
||||
|
||||
@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -29,7 +29,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
) -> (
|
||||
WorkflowAppBlockingResponse
|
||||
| WorkflowAppPausedBlockingResponse
|
||||
| Generator[WorkflowAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -59,7 +59,7 @@ def stream_topic_events(
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
if terminal_events is None:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
|
||||
@ -29,7 +29,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
@ -633,7 +637,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
) -> (
|
||||
WorkflowAppBlockingResponse
|
||||
| WorkflowAppPausedBlockingResponse
|
||||
| Generator[WorkflowAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return blocking_response.model_dump()
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -45,12 +45,15 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
|
||||
"""
|
||||
To blocking response.
|
||||
:return:
|
||||
"""
|
||||
human_input_responses: list[HumanInputRequiredResponse] = []
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
return WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
if human_input_responses:
|
||||
return self._build_paused_blocking_response_from_human_input(human_input_responses)
|
||||
|
||||
raise ValueError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _build_paused_blocking_response_from_human_input(
|
||||
self, human_input_responses: list[HumanInputRequiredResponse]
|
||||
) -> WorkflowAppPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
created_at = int(runtime_state.start_at)
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=human_input_responses[-1].workflow_run_id,
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id=human_input_responses[-1].workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
|
||||
total_tokens=runtime_state.total_tokens,
|
||||
total_steps=runtime_state.node_run_steps,
|
||||
created_at=created_at,
|
||||
finished_at=None,
|
||||
paused_nodes=paused_nodes,
|
||||
reasons=reasons,
|
||||
),
|
||||
)
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
"""
|
||||
Public pause-reason payload used by blocking responses when only
|
||||
``human_input_required`` events are available.
|
||||
"""
|
||||
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@classmethod
|
||||
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
|
||||
return cls(
|
||||
form_id=data.form_id,
|
||||
node_id=data.node_id,
|
||||
node_title=data.node_title,
|
||||
form_content=data.form_content,
|
||||
inputs=data.inputs,
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
ChatbotAppPausedBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
mode: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
workflow_run_id: str
|
||||
answer: str
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
|
||||
status: WorkflowExecutionStatus
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
data: Data
|
||||
|
||||
|
||||
class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
CompletionAppBlockingResponse entity
|
||||
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
WorkflowAppPausedBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int | None
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
AgentLogStreamResponse entity
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator # Changed from Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
|
||||
token = _current_file_access_scope.set(scope)
|
||||
try:
|
||||
yield
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
@ -15,8 +16,21 @@ from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
"""Resolves and returns LLM credentials for a given provider and model.
|
||||
|
||||
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||
Because of that cache, a single instance can return stale credentials after
|
||||
the tenant or provider configuration changes (e.g. API key rotation).
|
||||
|
||||
Do **not** keep one instance for the lifetime of a process or across
|
||||
unrelated invocations. Create a new provider per request, workflow run, or
|
||||
other bounded scope where up-to-date credentials matter.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -31,8 +45,12 @@ class DifyCredentialsProvider:
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
self.provider_manager = provider_manager
|
||||
self.credentials_cache = {}
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
if (provider_name, model_name) in self.credentials_cache:
|
||||
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
@ -47,6 +65,7 @@ class DifyCredentialsProvider:
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -66,7 +85,8 @@ class DifyModelFactory:
|
||||
provider_manager=create_plugin_provider_manager(
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
),
|
||||
enable_credentials_cache=True,
|
||||
)
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -85,7 +105,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
model_manager = ModelManager(provider_manager=provider_manager)
|
||||
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
|
||||
|
||||
return (
|
||||
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
||||
|
||||
41
api/core/helper/creators.py
Normal file
41
api/core/helper/creators.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class SuggestedQuestionsAfterAnswerOutputParser:
|
||||
def __init__(self, instruction_prompt: str | None = None) -> None:
|
||||
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
|
||||
|
||||
@staticmethod
|
||||
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
|
||||
if not instruction_prompt or not instruction_prompt.strip():
|
||||
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self._instruction_prompt
|
||||
|
||||
@ -107,6 +107,7 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
|
||||
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
|
||||
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||
" in the long text. Please think step by step."
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
@ -36,11 +37,13 @@ class ModelInstance:
|
||||
Model instance class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
if credentials is None:
|
||||
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.credentials = credentials
|
||||
# Runtime LLM invocation fields.
|
||||
self.parameters: Mapping[str, Any] = {}
|
||||
self.stop: Sequence[str] = ()
|
||||
@ -434,8 +437,30 @@ class ModelInstance:
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
|
||||
|
||||
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
|
||||
``(tenant_id, provider, model_type, model)`` are stored in
|
||||
``_credentials_cache`` and reused. That can return **stale** credentials after
|
||||
API keys or provider settings change, so a manager constructed with
|
||||
``enable_credentials_cache=True`` should not be kept for the lifetime of a
|
||||
process or shared across unrelated work. Prefer a new manager per request,
|
||||
workflow run, or similar bounded scope.
|
||||
|
||||
The default is ``enable_credentials_cache=False``; in that mode the internal
|
||||
credential cache is not populated, and each ``get_model_instance`` call
|
||||
loads credentials from the current provider configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
*,
|
||||
enable_credentials_cache: bool = False,
|
||||
) -> None:
|
||||
self._provider_manager = provider_manager
|
||||
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
|
||||
self._enable_credentials_cache = enable_credentials_cache
|
||||
|
||||
@classmethod
|
||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||
@ -463,8 +488,19 @@ class ModelManager:
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
model_instance = ModelInstance(provider_model_bundle, model)
|
||||
return model_instance
|
||||
cred_cache_key = (tenant_id, provider, model_type.value, model)
|
||||
|
||||
if cred_cache_key in self._credentials_cache:
|
||||
return ModelInstance(
|
||||
provider_model_bundle,
|
||||
model,
|
||||
deepcopy(self._credentials_cache[cred_cache_key]),
|
||||
)
|
||||
|
||||
ret = ModelInstance(provider_model_bundle, model)
|
||||
if self._enable_credentials_cache:
|
||||
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
|
||||
return ret
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
|
||||
@ -70,12 +70,32 @@ class ProviderManager:
|
||||
Request-bound managers may carry caller identity in that runtime, and the
|
||||
resulting ``ProviderConfiguration`` objects must reuse it for downstream
|
||||
model-type and schema lookups.
|
||||
|
||||
Configuration assembly is cached per manager instance so call chains that
|
||||
share one request-scoped manager can reuse the same provider graph instead
|
||||
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
|
||||
when a long-lived manager needs to observe writes performed within the same
|
||||
instance scope.
|
||||
"""
|
||||
|
||||
decoding_rsa_key: Any | None
|
||||
decoding_cipher_rsa: Any | None
|
||||
_model_runtime: ModelRuntime
|
||||
_configurations_cache: dict[str, ProviderConfigurations]
|
||||
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
self._model_runtime = model_runtime
|
||||
self._configurations_cache = {}
|
||||
|
||||
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
|
||||
"""Drop assembled provider configurations cached on this manager instance."""
|
||||
if tenant_id is None:
|
||||
self._configurations_cache.clear()
|
||||
return
|
||||
|
||||
self._configurations_cache.pop(tenant_id, None)
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
|
||||
"""
|
||||
@ -114,6 +134,10 @@ class ProviderManager:
|
||||
:param tenant_id:
|
||||
:return:
|
||||
"""
|
||||
cached_configurations = self._configurations_cache.get(tenant_id)
|
||||
if cached_configurations is not None:
|
||||
return cached_configurations
|
||||
|
||||
# Get all provider records of the workspace
|
||||
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
|
||||
|
||||
@ -273,6 +297,8 @@ class ProviderManager:
|
||||
|
||||
provider_configurations[str(provider_id_entity)] = provider_configuration
|
||||
|
||||
self._configurations_cache[tenant_id] = provider_configurations
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
|
||||
@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
|
||||
}
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
|
||||
if keyword_data_source_type == "database":
|
||||
if dataset_keyword_table is None:
|
||||
return
|
||||
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
|
||||
db.session.commit()
|
||||
else:
|
||||
@ -154,7 +156,8 @@ class Jieba(BaseKeyword):
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
if keyword_table_dict:
|
||||
return dict(keyword_table_dict["__data__"]["table"])
|
||||
data: Any = keyword_table_dict["__data__"]
|
||||
return dict(data["table"])
|
||||
else:
|
||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from operator import itemgetter
|
||||
from typing import cast
|
||||
|
||||
@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
|
||||
|
||||
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
|
||||
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
|
||||
top_k = kwargs.pop("topK", top_k)
|
||||
top_k = cast(int | None, kwargs.pop("topK", top_k))
|
||||
if top_k is None:
|
||||
top_k = 20
|
||||
cut = getattr(jieba, "cut", None)
|
||||
if self._lcut:
|
||||
tokens = self._lcut(sentence)
|
||||
elif callable(cut):
|
||||
tokens = list(cut(sentence))
|
||||
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
|
||||
else:
|
||||
tokens = re.findall(r"\w+", sentence)
|
||||
|
||||
@ -106,9 +109,9 @@ class JiebaKeywordTableHandler:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = self._tfidf.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
topK=max_keywords_per_chunk or 10,
|
||||
)
|
||||
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
|
||||
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
||||
keywords = cast(list[str], keywords)
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(set(keywords)))
|
||||
|
||||
@ -158,7 +158,7 @@ class RetrievalService:
|
||||
)
|
||||
|
||||
if futures:
|
||||
for future in concurrent.futures.as_completed(futures, timeout=3600):
|
||||
for _ in concurrent.futures.as_completed(futures, timeout=3600):
|
||||
if exceptions:
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
|
||||
@ -94,6 +94,7 @@ class ExtractProcessor:
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE:
|
||||
upload_file = extract_setting.upload_file
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
upload_file = extract_setting.upload_file
|
||||
if not file_path:
|
||||
@ -104,6 +105,7 @@ class ExtractProcessor:
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
assert upload_file is not None, "upload_file is required"
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
extractor: BaseExtractor | None = None
|
||||
if etl_type == "Unstructured":
|
||||
|
||||
@ -29,10 +29,10 @@ class FunctionCallMultiDatasetRouter:
|
||||
SystemPromptMessage(content="You are a helpful AI assistant."),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
result: LLMResult = model_instance.invoke_llm(
|
||||
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
stream=False, # pyright: ignore[reportArgumentType]
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import re
|
||||
from collections.abc import Collection
|
||||
from collections.abc import Set as AbstractSet
|
||||
from typing import Any, Literal
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
|
||||
cls: type[T],
|
||||
embedding_model_instance: ModelInstance | None,
|
||||
allowed_special: Literal["all"] | set[str] = set(),
|
||||
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
|
||||
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
return [len(text) for text in texts]
|
||||
|
||||
_ = _token_encoder # kept for future token-length wiring
|
||||
return cls(length_function=_character_encoder, **kwargs)
|
||||
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: str | None = None,
|
||||
allowed_special: Literal["all"] | Set[str] = set(),
|
||||
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
|
||||
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
self._tokenizer = enc
|
||||
self._allowed_special = allowed_special
|
||||
self._disallowed_special = disallowed_special
|
||||
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
|
||||
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
def _encode(_text: str) -> list[int]:
|
||||
|
||||
@ -14,23 +14,23 @@ from configs import dify_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
params: Parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@ -105,7 +105,7 @@ class Article:
|
||||
|
||||
|
||||
def extract_using_readabilipy(html: str):
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
|
||||
article = Article(
|
||||
title=json_article.get("title") or "",
|
||||
author=json_article.get("byline") or "",
|
||||
|
||||
@ -12,20 +12,16 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
_FORM_TOKEN_PRIORITY = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
|
||||
tokens_by_form_id: dict[str, tuple[int, str]] = {}
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
|
||||
if priority is None or not recipient.access_token:
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
candidate = (priority, recipient.access_token)
|
||||
current = tokens_by_form_id.get(recipient.form_id)
|
||||
if current is None or candidate[0] < current[0]:
|
||||
tokens_by_form_id[recipient.form_id] = candidate
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
|
||||
73
api/core/workflow/human_input_policy.py
Normal file
73
api/core/workflow/human_input_policy.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class HumanInputSurface(StrEnum):
|
||||
SERVICE_API = "service_api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
}
|
||||
|
||||
# A single HITL form can have multiple recipient records; this shared priority
|
||||
# keeps every API surface consistent about which resume token to expose.
|
||||
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def is_recipient_type_allowed_for_surface(
|
||||
recipient_type: RecipientType | None,
|
||||
surface: HumanInputSurface,
|
||||
) -> bool:
|
||||
if recipient_type is None:
|
||||
return False
|
||||
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
|
||||
|
||||
def get_preferred_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
) -> str | None:
|
||||
chosen_token: str | None = None
|
||||
chosen_priority: int | None = None
|
||||
for recipient_type, token in recipients:
|
||||
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
|
||||
if priority is None or not token:
|
||||
continue
|
||||
if chosen_priority is None or priority < chosen_priority:
|
||||
chosen_priority = priority
|
||||
chosen_token = token
|
||||
return chosen_token
|
||||
|
||||
|
||||
def enrich_human_input_pause_reasons(
|
||||
reasons: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
form_tokens_by_form_id: Mapping[str, str],
|
||||
expiration_times_by_form_id: Mapping[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
updated = dict(reason)
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is not None:
|
||||
updated["expiration_time"] = expiration_time
|
||||
enriched.append(updated)
|
||||
return enriched
|
||||
172
api/dev/generate_swagger_specs.py
Normal file
172
api/dev/generate_swagger_specs.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Generate Flask-RESTX Swagger 2.0 specs without booting the full backend.
|
||||
|
||||
This helper intentionally avoids `app_factory.create_app()`. The normal backend
|
||||
startup eagerly initializes database, Redis, Celery, and storage extensions,
|
||||
which is unnecessary when the goal is only to serialize the Flask-RESTX
|
||||
`/swagger.json` documents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx.swagger import Swagger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_ROOT))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpecTarget:
|
||||
route: str
|
||||
filename: str
|
||||
|
||||
|
||||
SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
|
||||
)
|
||||
|
||||
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
|
||||
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
|
||||
|
||||
|
||||
def _apply_runtime_defaults() -> None:
|
||||
"""Force the small config surface required for Swagger generation."""
|
||||
|
||||
os.environ.setdefault("SECRET_KEY", "spec-export")
|
||||
os.environ.setdefault("STORAGE_TYPE", "local")
|
||||
os.environ.setdefault("STORAGE_LOCAL_PATH", "/tmp/dify-storage")
|
||||
os.environ.setdefault("SWAGGER_UI_ENABLED", "true")
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
dify_config.SECRET_KEY = os.environ["SECRET_KEY"]
|
||||
dify_config.STORAGE_TYPE = "local"
|
||||
dify_config.STORAGE_LOCAL_PATH = os.environ["STORAGE_LOCAL_PATH"]
|
||||
dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true"
|
||||
|
||||
|
||||
def _patch_swagger_for_inline_nested_dicts() -> None:
|
||||
"""Teach Flask-RESTX Swagger generation to tolerate inline nested field maps.
|
||||
|
||||
Some existing controllers use `fields.Nested({...})` with a raw field mapping
|
||||
instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous
|
||||
dicts during schema registration, so this helper upgrades them into temporary
|
||||
named models at export time.
|
||||
"""
|
||||
|
||||
if getattr(Swagger, "_dify_inline_nested_dict_patch", False):
|
||||
return
|
||||
|
||||
def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object:
|
||||
anonymous_models = getattr(self, "_anonymous_inline_models", None)
|
||||
if anonymous_models is None:
|
||||
anonymous_models = {}
|
||||
self._anonymous_inline_models = anonymous_models
|
||||
|
||||
anonymous_name = anonymous_models.get(id(nested_fields))
|
||||
if anonymous_name is None:
|
||||
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
|
||||
anonymous_models[id(nested_fields)] = anonymous_name
|
||||
self.api.model(anonymous_name, nested_fields)
|
||||
|
||||
return self.api.models[anonymous_name]
|
||||
|
||||
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
|
||||
if isinstance(model, dict):
|
||||
model = get_or_create_inline_model(self, model)
|
||||
|
||||
return _ORIGINAL_REGISTER_MODEL(self, model)
|
||||
|
||||
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
|
||||
nested = getattr(field, "nested", None)
|
||||
if isinstance(nested, dict):
|
||||
field.model = get_or_create_inline_model(self, nested) # type: ignore
|
||||
|
||||
_ORIGINAL_REGISTER_FIELD(self, field)
|
||||
|
||||
Swagger.register_model = register_model_with_inline_dict_support
|
||||
Swagger.register_field = register_field_with_inline_dict_support
|
||||
Swagger._dify_inline_nested_dict_patch = True
|
||||
|
||||
|
||||
def create_spec_app() -> Flask:
|
||||
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
|
||||
|
||||
_apply_runtime_defaults()
|
||||
_patch_swagger_for_inline_nested_dicts()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def generate_specs(output_dir: Path) -> list[Path]:
|
||||
"""Write all Swagger specs to `output_dir` and return the written paths."""
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
app = create_spec_app()
|
||||
client = app.test_client()
|
||||
|
||||
written_paths: list[Path] = []
|
||||
for target in SPEC_TARGETS:
|
||||
response = client.get(target.route)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
|
||||
|
||||
payload = response.get_json()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"unexpected response payload for {target.route}")
|
||||
|
||||
output_path = output_dir / target.filename
|
||||
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
written_paths.append(output_path)
|
||||
|
||||
return written_paths
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("openapi"),
|
||||
help="Directory where the Swagger JSON files will be written.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
written_paths = generate_specs(args.output_dir)
|
||||
|
||||
for path in written_paths:
|
||||
logger.debug(path)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -1,5 +1,5 @@
|
||||
import contextvars
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator # Changed from Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
def preserve_flask_contexts(
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> Iterator[None]:
|
||||
) -> Generator[None, None, None]: # Changed from Iterator[None]
|
||||
"""
|
||||
A context manager that handles:
|
||||
1. flask-login's UserProxy copy
|
||||
|
||||
@ -42,8 +42,7 @@ class WorkflowComment(Base):
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(sa.Float)
|
||||
@ -152,8 +151,7 @@ class WorkflowCommentReply(Base):
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
@ -200,8 +198,7 @@ class WorkflowCommentMention(Base):
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator # Added Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@ -75,7 +75,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self) -> Iterator[Any]:
|
||||
def _get_cursor(self) -> Generator[Any, None, None]: # Changed from Iterator[Any]
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
|
||||
@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
auth = PasswordAuthenticator(config.user, config.password)
|
||||
options = ClusterOptions(auth)
|
||||
self._cluster = Cluster(config.connection_string, options)
|
||||
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
|
||||
self._bucket = self._cluster.bucket(config.bucket_name)
|
||||
self._scope = self._bucket.scope(config.scope_name)
|
||||
self._bucket_name = config.bucket_name
|
||||
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
|
||||
def _load_collection_fields(self, fields: list[str] | None = None):
|
||||
if fields is None:
|
||||
# Load collection fields from remote server
|
||||
collection_info = self._client.describe_collection(self._collection_name)
|
||||
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
|
||||
fields = [field["name"] for field in collection_info["fields"]]
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
|
||||
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
|
||||
return False
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
milvus_version_raw = self._client.get_server_version()
|
||||
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
|
||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||
if "Zilliz Cloud" in milvus_version:
|
||||
return True
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import jieba.posseg as pseg # type: ignore
|
||||
import numpy
|
||||
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
|
||||
oracledb.defaults.fetch_lobs = False
|
||||
|
||||
|
||||
class _OraclePoolParams(TypedDict, total=False):
|
||||
user: str
|
||||
password: str
|
||||
dsn: str
|
||||
min: int
|
||||
max: int
|
||||
increment: int
|
||||
config_dir: str | None
|
||||
wallet_location: str | None
|
||||
wallet_password: str | None
|
||||
|
||||
|
||||
class OracleVectorConfig(BaseModel):
|
||||
user: str
|
||||
password: str
|
||||
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
|
||||
return connection
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
pool_params = {
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"dsn": config.dsn,
|
||||
"min": 1,
|
||||
"max": 5,
|
||||
"increment": 1,
|
||||
}
|
||||
pool_params = _OraclePoolParams(
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
dsn=config.dsn,
|
||||
min=1,
|
||||
max=5,
|
||||
increment=1,
|
||||
)
|
||||
if config.is_autonomous:
|
||||
pool_params.update(
|
||||
{
|
||||
"config_dir": config.config_dir,
|
||||
"wallet_location": config.wallet_location,
|
||||
"wallet_password": config.wallet_password,
|
||||
}
|
||||
)
|
||||
pool_params["config_dir"] = config.config_dir
|
||||
pool_params["wallet_location"] = config.wallet_location
|
||||
pool_params["wallet_password"] = config.wallet_password
|
||||
return oracledb.create_pool(**pool_params)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
||||
@ -6,7 +6,7 @@ requires-python = "~=3.12.0"
|
||||
dependencies = [
|
||||
# Legacy: mature and widely deployed
|
||||
"bleach>=6.3.0",
|
||||
"boto3>=1.42.88",
|
||||
"boto3>=1.42.91",
|
||||
"celery>=5.6.3",
|
||||
"croniter>=6.2.2",
|
||||
"flask-cors>=6.0.2",
|
||||
@ -30,7 +30,7 @@ dependencies = [
|
||||
"flask-migrate>=4.1.0,<5.0.0",
|
||||
"flask-orjson>=2.0.0,<3.0.0",
|
||||
"flask-restx>=1.3.2,<2.0.0",
|
||||
"google-cloud-aiplatform>=1.147.0,<2.0.0",
|
||||
"google-cloud-aiplatform>=1.148.1,<2.0.0",
|
||||
"httpx[socks]>=0.28.1,<1.0.0",
|
||||
"opentelemetry-distro>=0.62b0,<1.0.0",
|
||||
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
|
||||
@ -46,7 +46,7 @@ dependencies = [
|
||||
"fastopenapi[flask]~=0.7.0",
|
||||
"graphon~=0.2.2",
|
||||
"httpx-sse~=0.4.0",
|
||||
"json-repair~=0.59.2",
|
||||
"json-repair~=0.59.4",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
@ -114,10 +114,10 @@ override-dependencies = [
|
||||
dev = [
|
||||
"coverage>=7.13.4",
|
||||
"dotenv-linter>=0.7.0",
|
||||
"faker>=20.1.0",
|
||||
"faker>=40.15.0",
|
||||
"lxml-stubs>=0.5.1",
|
||||
"basedpyright>=1.39.0",
|
||||
"ruff>=0.15.10",
|
||||
"basedpyright>=1.39.3",
|
||||
"ruff>=0.15.11",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-benchmark>=5.2.3",
|
||||
"pytest-cov>=7.1.0",
|
||||
@ -157,14 +157,14 @@ dev = [
|
||||
"types-tensorflow>=2.18.0.20260408",
|
||||
"types-tqdm>=4.67.3.20260408",
|
||||
"types-ujson>=5.10.0",
|
||||
"boto3-stubs>=1.42.88",
|
||||
"boto3-stubs>=1.42.92",
|
||||
"types-jmespath>=1.1.0.20260408",
|
||||
"hypothesis>=6.151.12",
|
||||
"hypothesis>=6.152.1",
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=2.0.0.20260408",
|
||||
"types_setuptools>=82.0.0.20260408",
|
||||
"pandas-stubs>=3.0.0",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"scipy-stubs>=1.17.1.4",
|
||||
"types-python-http-client>=3.3.7.20260408",
|
||||
"import-linter>=2.3",
|
||||
"types-redis>=4.6.0.20241004",
|
||||
@ -173,8 +173,8 @@ dev = [
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.61.1",
|
||||
"xinference-client>=2.4.0",
|
||||
"pyrefly>=0.62.0",
|
||||
"xinference-client>=2.5.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
@ -183,13 +183,13 @@ dev = [
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob>=12.28.0",
|
||||
"bce-python-sdk>=0.9.69",
|
||||
"bce-python-sdk>=0.9.70",
|
||||
"cos-python-sdk-v5>=1.9.41",
|
||||
"esdk-obs-python>=3.22.2",
|
||||
"google-cloud-storage>=3.10.1",
|
||||
"opendal>=0.46.0",
|
||||
"oss2>=2.19.1",
|
||||
"supabase>=2.18.1",
|
||||
"supabase>=2.28.3",
|
||||
"tos>=2.9.0",
|
||||
]
|
||||
|
||||
@ -272,7 +272,7 @@ vdb-vastbase = ["dify-vdb-vastbase"]
|
||||
vdb-vikingdb = ["dify-vdb-vikingdb"]
|
||||
vdb-weaviate = ["dify-vdb-weaviate"]
|
||||
# Optional client used by some tests / integrations (not a vector backend plugin)
|
||||
vdb-xinference = ["xinference-client>=2.4.0"]
|
||||
vdb-xinference = ["xinference-client>=2.5.0"]
|
||||
|
||||
trace-all = [
|
||||
"dify-trace-aliyun",
|
||||
|
||||
@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputForm
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
|
||||
def _build_human_input_required_reason(
|
||||
reason_model: WorkflowPauseReason,
|
||||
form_model: HumanInputForm | None,
|
||||
recipients: Sequence[HumanInputFormRecipient] = (),
|
||||
) -> HumanInputRequired:
|
||||
form_content = ""
|
||||
inputs = []
|
||||
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
|
||||
resolved_default_values = dict(definition.default_values)
|
||||
node_title = definition.node_title or node_title
|
||||
|
||||
return HumanInputRequired(
|
||||
reason = HumanInputRequired(
|
||||
form_id=form_id,
|
||||
form_content=form_content,
|
||||
inputs=inputs,
|
||||
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
|
||||
node_title=node_title,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
return reason
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
|
||||
for form in session.scalars(form_stmt).all():
|
||||
form_models[form.id] = form
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
|
||||
if form_ids:
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(recipient_stmt).all():
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
|
||||
|
||||
pause_reasons: list[PauseReason] = []
|
||||
for reason in pause_reason_models:
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_model = form_models.get(reason.form_id)
|
||||
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
|
||||
pause_reasons.append(
|
||||
_build_human_input_required_reason(
|
||||
reason,
|
||||
form_model,
|
||||
recipients_by_form_id.get(reason.form_id, ()),
|
||||
)
|
||||
)
|
||||
else:
|
||||
pause_reasons.append(reason.to_entity())
|
||||
return pause_reasons
|
||||
|
||||
@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
class AccountService:
|
||||
# Phase-bound token metadata for the change-email flow. Tokens carry the
|
||||
# current phase so that downstream endpoints can enforce proper progression
|
||||
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
|
||||
CHANGE_EMAIL_PHASE_OLD = "old_email"
|
||||
CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified"
|
||||
CHANGE_EMAIL_PHASE_NEW = "new_email"
|
||||
CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified"
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
@ -576,13 +584,20 @@ class AccountService:
|
||||
raise ValueError("Email must be provided.")
|
||||
if not phase:
|
||||
raise ValueError("phase must be provided.")
|
||||
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
|
||||
raise ValueError("phase must be one of old_email or new_email.")
|
||||
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
code, token = cls.generate_change_email_token(
|
||||
account_email,
|
||||
account,
|
||||
old_email=old_email,
|
||||
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
|
||||
)
|
||||
|
||||
send_change_mail_task.delay(
|
||||
language=language,
|
||||
|
||||
@ -164,6 +164,7 @@ class AppGenerateService:
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
@ -185,6 +186,10 @@ class AppGenerateService:
|
||||
else:
|
||||
# Blocking mode: run synchronously and return JSON instead of SSE
|
||||
# Keep behaviour consistent with WORKFLOW blocking branch.
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory.get_session_maker(),
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
@ -196,6 +201,7 @@ class AppGenerateService:
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
|
||||
@ -5,6 +5,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cachetools.func import ttl_cache
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
@ttl_cache(ttl=5)
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request("GET", "/info")
|
||||
|
||||
|
||||
@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
trial_models: list[str] = []
|
||||
enable_creators_platform: bool = False
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
|
||||
@ -241,6 +242,9 @@ class FeatureService:
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
system_features.enable_marketplace = True
|
||||
|
||||
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
system_features.enable_creators_platform = True
|
||||
|
||||
return system_features
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Generator, Sequence # Changed Iterator to Generator
|
||||
from contextlib import contextmanager, suppress
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Literal
|
||||
@ -324,7 +324,7 @@ class FileService:
|
||||
def build_upload_files_zip_tempfile(
|
||||
*,
|
||||
upload_files: Sequence[UploadFile],
|
||||
) -> Iterator[str]:
|
||||
) -> Generator[str, None, None]: # Changed from Iterator[str]
|
||||
"""
|
||||
Build a ZIP from `UploadFile`s and yield a tempfile path.
|
||||
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -36,6 +36,10 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
|
||||
metadata_filtering_conditions: dict[str, Any]
|
||||
|
||||
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(
|
||||
@ -51,17 +55,18 @@ class HitTestingService:
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
assert isinstance(retrieval_model, dict)
|
||||
resolved_retrieval_model = cast(
|
||||
HitTestingRetrievalModelDict,
|
||||
retrieval_model or dataset.retrieval_model or default_retrieval_model,
|
||||
)
|
||||
document_ids_filter = None
|
||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions and query:
|
||||
metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions_raw and query:
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
|
||||
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=[dataset.id],
|
||||
@ -78,19 +83,21 @@ class HitTestingService:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
retrieval_method=RetrievalMethod(
|
||||
resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
|
||||
),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
top_k=resolved_retrieval_model.get("top_k", 4),
|
||||
score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
|
||||
if resolved_retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
reranking_model=resolved_retrieval_model.get("reranking_model", None)
|
||||
if resolved_retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=resolved_retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
@ -521,7 +521,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
@ -635,7 +635,7 @@ class TriggerProviderService:
|
||||
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.entities.task_entities import (
|
||||
HumanInputRequiredResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@ -26,6 +27,10 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from models.human_input import HumanInputForm
|
||||
from models.model import AppMode, Message
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
@ -59,8 +64,10 @@ def build_workflow_event_stream(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
idle_timeout: float = 300,
|
||||
ping_interval: float = 10.0,
|
||||
close_on_pause: bool = True,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
@ -115,13 +122,15 @@ def build_workflow_event_stream(
|
||||
message_context=message_context,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
)
|
||||
|
||||
for event in snapshot_events:
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
yield event
|
||||
if _is_terminal_event(event, include_paused=True):
|
||||
if _is_terminal_event(event, close_on_pause=close_on_pause):
|
||||
return
|
||||
|
||||
while True:
|
||||
@ -146,7 +155,7 @@ def build_workflow_event_stream(
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
yield event
|
||||
if _is_terminal_event(event, include_paused=True):
|
||||
if _is_terminal_event(event, close_on_pause=close_on_pause):
|
||||
return
|
||||
finally:
|
||||
buffer_state.stop_event.set()
|
||||
@ -207,6 +216,8 @@ def _build_snapshot_events(
|
||||
message_context: MessageContext | None,
|
||||
pause_entity: WorkflowPauseEntity | None,
|
||||
resumption_context: WorkflowResumptionContext | None,
|
||||
session_maker: sessionmaker[Session] | None = None,
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
events: list[Mapping[str, Any]] = []
|
||||
|
||||
@ -241,12 +252,24 @@ def _build_snapshot_events(
|
||||
events.append(node_finished)
|
||||
|
||||
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
|
||||
for human_input_event in _build_human_input_required_events(
|
||||
workflow_run_id=workflow_run.id,
|
||||
task_id=task_id,
|
||||
pause_entity=pause_entity,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
):
|
||||
_apply_message_context(human_input_event, message_context)
|
||||
events.append(human_input_event)
|
||||
|
||||
pause_event = _build_pause_event(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run_id=workflow_run.id,
|
||||
task_id=task_id,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=human_input_surface,
|
||||
)
|
||||
if pause_event is not None:
|
||||
_apply_message_context(pause_event, message_context)
|
||||
@ -314,6 +337,97 @@ def _build_node_started_event(
|
||||
return response.to_ignore_detail_dict()
|
||||
|
||||
|
||||
def _build_human_input_required_events(
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
task_id: str,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
session_maker: sessionmaker[Session] | None,
|
||||
human_input_surface: HumanInputSurface | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
|
||||
human_input_form_ids = [
|
||||
form_id
|
||||
for reason in reasons
|
||||
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
for form_id in [reason.get("form_id")]
|
||||
if isinstance(form_id, str)
|
||||
]
|
||||
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
with session_maker() as session:
|
||||
for form_id, expiration_time, form_definition in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
|
||||
try:
|
||||
definition_payload = json.loads(form_definition) if form_definition else {}
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
continue
|
||||
|
||||
form_id_raw = reason.get("form_id")
|
||||
node_id_raw = reason.get("node_id")
|
||||
node_title_raw = reason.get("node_title")
|
||||
form_content_raw = reason.get("form_content")
|
||||
if not isinstance(form_id_raw, str):
|
||||
continue
|
||||
if not isinstance(node_id_raw, str):
|
||||
continue
|
||||
if not isinstance(node_title_raw, str):
|
||||
continue
|
||||
if not isinstance(form_content_raw, str):
|
||||
continue
|
||||
form_id = form_id_raw
|
||||
node_id = node_id_raw
|
||||
node_title = node_title_raw
|
||||
form_content = form_content_raw
|
||||
|
||||
inputs = reason.get("inputs")
|
||||
actions = reason.get("actions")
|
||||
resolved_default_values = reason.get("resolved_default_values")
|
||||
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is None:
|
||||
continue
|
||||
|
||||
response = HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=form_id,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
form_content=form_content,
|
||||
inputs=inputs if isinstance(inputs, list) else [],
|
||||
actions=actions if isinstance(actions, list) else [],
|
||||
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
|
||||
form_token=form_tokens_by_form_id.get(form_id),
|
||||
resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}),
|
||||
expiration_time=expiration_time,
|
||||
),
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
events.append(payload)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def _build_node_finished_event(
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
@ -356,6 +470,8 @@ def _build_pause_event(
|
||||
task_id: str,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
resumption_context: WorkflowResumptionContext | None,
|
||||
session_maker: sessionmaker[Session] | None,
|
||||
human_input_surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
paused_nodes: list[str] = []
|
||||
outputs: dict[str, Any] = {}
|
||||
@ -365,6 +481,36 @@ def _build_pause_event(
|
||||
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
|
||||
|
||||
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
|
||||
human_input_form_ids = [
|
||||
form_id
|
||||
for reason in reasons
|
||||
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
for form_id in [reason.get("form_id")]
|
||||
if isinstance(form_id, str)
|
||||
]
|
||||
form_tokens_by_form_id: dict[str, str] = {}
|
||||
expiration_times_by_form_id: dict[str, int] = {}
|
||||
if human_input_form_ids and session_maker is not None:
|
||||
with session_maker() as session:
|
||||
form_tokens_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=human_input_surface,
|
||||
)
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
for row in session.execute(stmt):
|
||||
form_id, expiration_time, *_rest = row
|
||||
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
reasons = enrich_human_input_pause_reasons(
|
||||
reasons,
|
||||
form_tokens_by_form_id=form_tokens_by_form_id,
|
||||
expiration_times_by_form_id=expiration_times_by_form_id,
|
||||
)
|
||||
|
||||
response = WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
|
||||
return event
|
||||
|
||||
|
||||
def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
|
||||
def _is_terminal_event(
|
||||
event: Mapping[str, Any] | str,
|
||||
close_on_pause: bool = True,
|
||||
*,
|
||||
include_paused: bool | None = None,
|
||||
) -> bool:
|
||||
if include_paused is not None:
|
||||
close_on_pause = include_paused
|
||||
if not isinstance(event, Mapping):
|
||||
return False
|
||||
event_type = event.get("event")
|
||||
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
|
||||
return True
|
||||
if include_paused:
|
||||
if close_on_pause:
|
||||
return event_type == StreamEvent.WORKFLOW_PAUSED.value
|
||||
return False
|
||||
|
||||
@ -399,6 +399,8 @@ def _resume_advanced_chat(
|
||||
workflow_run_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> None:
|
||||
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
@ -426,7 +428,7 @@ def _resume_advanced_chat(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
application_generate_entity=generate_entity,
|
||||
application_generate_entity=resumed_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -436,9 +438,8 @@ def _resume_advanced_chat(
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
|
||||
|
||||
|
||||
def _resume_workflow(
|
||||
@ -455,6 +456,8 @@ def _resume_workflow(
|
||||
workflow_run_repo,
|
||||
pause_entity,
|
||||
) -> None:
|
||||
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
@ -480,7 +483,7 @@ def _resume_workflow(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=generate_entity,
|
||||
application_generate_entity=resumed_generate_entity,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
@ -490,11 +493,18 @@ def _resume_workflow(
|
||||
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
if generate_entity.stream:
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
|
||||
assert isinstance(response, Generator)
|
||||
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
|
||||
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
try:
|
||||
workflow_run_repo.delete_workflow_pause(pause_entity)
|
||||
except Exception as exc:
|
||||
if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc):
|
||||
raise
|
||||
logger.info(
|
||||
"Skipped deleting workflow pause %s after resume because it was already replaced or removed",
|
||||
pause_entity.id,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")
|
||||
|
||||
@ -2,27 +2,31 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Integration tests for _build_human_input_required_reason using real DB models."""
|
||||
|
||||
def test_builds_reason_from_form_definition(
|
||||
def test_prefers_standalone_web_app_token_when_available(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Build the graph pause reason from the stored form definition."""
|
||||
"""Use the public standalone web-app token for service API payloads."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
backstage_access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=backstage_access_token,
|
||||
)
|
||||
console_access_token = secrets.token_urlsafe(8)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload="{}",
|
||||
access_token=console_access_token,
|
||||
)
|
||||
web_app_access_token = secrets.token_urlsafe(8)
|
||||
web_app_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload="{}",
|
||||
access_token=web_app_access_token,
|
||||
)
|
||||
db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
|
||||
db_session_with_containers.flush()
|
||||
# Create a pause so the reason has a valid pause_id
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
|
||||
# Refresh to ensure we have DB-round-tripped objects
|
||||
db_session_with_containers.refresh(form_model)
|
||||
db_session_with_containers.refresh(reason_model)
|
||||
db_session_with_containers.refresh(backstage_recipient)
|
||||
db_session_with_containers.refresh(console_recipient)
|
||||
db_session_with_containers.refresh(web_app_recipient)
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model)
|
||||
reason = _build_human_input_required_reason(
|
||||
reason_model,
|
||||
form_model,
|
||||
[backstage_recipient, console_recipient, web_app_recipient],
|
||||
)
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.node_title == "Ask Name"
|
||||
@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
assert reason.resolved_default_values == {"name": "Alice"}
|
||||
assert not hasattr(reason, "form_token")
|
||||
|
||||
def test_falls_back_to_console_token_when_web_app_token_missing(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Use the console token only when no standalone web-app token exists."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
|
||||
form_model = HumanInputForm(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_run_id=str(uuid4()),
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
backstage_access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=backstage_access_token,
|
||||
)
|
||||
console_access_token = secrets.token_urlsafe(8)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload="{}",
|
||||
access_token=console_access_token,
|
||||
)
|
||||
db_session_with_containers.add_all([backstage_recipient, console_recipient])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.flush()
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=form_model.id,
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
db_session_with_containers.add(reason_model)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert not hasattr(reason, "form_token")
|
||||
|
||||
37
api/tests/unit_tests/commands/test_generate_swagger_specs.py
Normal file
37
api/tests/unit_tests/commands/test_generate_swagger_specs.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Unit tests for the standalone Swagger export helper."""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_generate_swagger_specs_module():
|
||||
api_dir = Path(__file__).resolve().parents[3]
|
||||
script_path = api_dir / "dev" / "generate_swagger_specs.py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location("generate_swagger_specs", script_path)
|
||||
assert spec
|
||||
assert spec.loader
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module) # type: ignore[attr-defined]
|
||||
return module
|
||||
|
||||
|
||||
def test_generate_specs_writes_console_web_and_service_swagger_files(tmp_path):
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
written_paths = module.generate_specs(tmp_path)
|
||||
|
||||
assert [path.name for path in written_paths] == [
|
||||
"console-swagger.json",
|
||||
"web-swagger.json",
|
||||
"service-swagger.json",
|
||||
]
|
||||
|
||||
for path in written_paths:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
assert payload["swagger"] == "2.0"
|
||||
assert "paths" in payload
|
||||
@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
@ -68,7 +68,10 @@ class TestChangeEmailSend:
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
@ -85,12 +88,55 @@ class TestChangeEmailSend:
|
||||
email="new@example.com",
|
||||
old_email="current@example.com",
|
||||
language="en-US",
|
||||
phase="new_email",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@ -122,7 +168,12 @@ class TestChangeEmailValidity:
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
@ -138,11 +189,169 @@ class TestChangeEmailValidity:
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_upgrade_new_phase_token_to_new_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-verified-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
},
|
||||
)
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_validity_when_token_phase_is_unknown(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""A token whose phase marker is a string but not a known transition must be rejected."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else",
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailCheckApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_generate_token.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_validity_when_token_has_no_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailCheckApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_generate_token.assert_not_called()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@ -175,7 +384,11 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"old_email": "OLD@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@ -194,6 +407,155 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_reset_when_token_phase_is_not_new_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
# Simulate a token straight out of step #1 (phase=old_email) — exactly
|
||||
# the replay used in the advisory PoC.
|
||||
mock_get_data.return_value = {
|
||||
"email": "old@example.com",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "attacker@example.com", "token": "token-from-step1"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"email": "verified@example.com",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "attacker@example.com", "token": "token-verified"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
|
||||
|
||||
class TestAccountServiceSendChangeEmailEmail:
|
||||
"""Service-level coverage for the phase-bound changes in `send_change_email_email`."""
|
||||
|
||||
def test_should_raise_value_error_for_invalid_phase(self):
|
||||
with pytest.raises(ValueError, match="phase must be one of"):
|
||||
AccountService.send_change_email_email(
|
||||
email="user@example.com",
|
||||
old_email="user@example.com",
|
||||
phase="old_email_verified",
|
||||
)
|
||||
|
||||
@patch("services.account_service.send_change_mail_task")
|
||||
@patch("services.account_service.AccountService.change_email_rate_limiter")
|
||||
@patch("services.account_service.AccountService.generate_change_email_token")
|
||||
def test_should_stamp_phase_into_generated_token(
|
||||
self,
|
||||
mock_generate_token,
|
||||
mock_rate_limiter,
|
||||
mock_mail_task,
|
||||
):
|
||||
mock_rate_limiter.is_rate_limited.return_value = False
|
||||
mock_generate_token.return_value = ("123456", "the-token")
|
||||
|
||||
returned = AccountService.send_change_email_email(
|
||||
email="user@example.com",
|
||||
old_email="user@example.com",
|
||||
language="en-US",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
)
|
||||
|
||||
assert returned == "the-token"
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com",
|
||||
None,
|
||||
old_email="user@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
},
|
||||
)
|
||||
mock_mail_task.delay.assert_called_once()
|
||||
mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@ -2,14 +2,17 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace.endpoint import (
|
||||
EndpointCreateApi,
|
||||
EndpointDeleteApi,
|
||||
DeprecatedEndpointCreateApi,
|
||||
DeprecatedEndpointDeleteApi,
|
||||
DeprecatedEndpointUpdateApi,
|
||||
EndpointCollectionApi,
|
||||
EndpointDisableApi,
|
||||
EndpointEnableApi,
|
||||
EndpointItemApi,
|
||||
EndpointListApi,
|
||||
EndpointListForSinglePluginApi,
|
||||
EndpointUpdateApi,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
||||
@ -35,9 +38,9 @@ def patch_current_account(user_and_tenant):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointCreateApi:
|
||||
class TestEndpointCollectionApi:
|
||||
def test_create_success(self, app):
|
||||
api = EndpointCreateApi()
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -55,7 +58,7 @@ class TestEndpointCreateApi:
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_permission_denied(self, app):
|
||||
api = EndpointCreateApi()
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -75,7 +78,7 @@ class TestEndpointCreateApi:
|
||||
method(api)
|
||||
|
||||
def test_create_validation_error(self, app):
|
||||
api = EndpointCreateApi()
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -91,6 +94,27 @@ class TestEndpointCreateApi:
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestDeprecatedEndpointCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = DeprecatedEndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {"a": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListApi:
|
||||
def test_list_success(self, app):
|
||||
@ -146,9 +170,96 @@ class TestEndpointListForSinglePluginApi:
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDeleteApi:
|
||||
class TestEndpointItemApi:
|
||||
def test_delete_success(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.delete_endpoint",
|
||||
return_value=True,
|
||||
) as mock_delete,
|
||||
):
|
||||
result = method(api, "e1")
|
||||
|
||||
assert result["success"] is True
|
||||
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
|
||||
|
||||
def test_delete_service_failure(self, app):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api, "e1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
def test_update_success(self, app):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "new-name",
|
||||
"settings": {"x": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="PATCH", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.update_endpoint",
|
||||
return_value=True,
|
||||
) as mock_update,
|
||||
):
|
||||
result = method(api, "e1")
|
||||
|
||||
assert result["success"] is True
|
||||
mock_update.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
endpoint_id="e1",
|
||||
name="new-name",
|
||||
settings={"x": 1},
|
||||
)
|
||||
|
||||
def test_update_validation_error(self, app):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"settings": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="PATCH", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "e1")
|
||||
|
||||
def test_update_service_failure(self, app):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "n",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="PATCH", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api, "e1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestDeprecatedEndpointDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
@ -162,7 +273,7 @@ class TestEndpointDeleteApi:
|
||||
assert result["success"] is True
|
||||
|
||||
def test_delete_invalid_payload(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
@ -172,7 +283,7 @@ class TestEndpointDeleteApi:
|
||||
method(api)
|
||||
|
||||
def test_delete_service_failure(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
@ -187,9 +298,9 @@ class TestEndpointDeleteApi:
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointUpdateApi:
|
||||
class TestDeprecatedEndpointUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -207,7 +318,7 @@ class TestEndpointUpdateApi:
|
||||
assert result["success"] is True
|
||||
|
||||
def test_update_validation_error(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1", "settings": {}}
|
||||
@ -219,7 +330,7 @@ class TestEndpointUpdateApi:
|
||||
method(api)
|
||||
|
||||
def test_update_service_failure(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
@ -237,6 +348,36 @@ class TestEndpointUpdateApi:
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
class TestEndpointRouteMetadata:
|
||||
def test_legacy_write_routes_are_marked_deprecated(self):
|
||||
assert DeprecatedEndpointCreateApi.post.__apidoc__["deprecated"] is True
|
||||
assert DeprecatedEndpointDeleteApi.post.__apidoc__["deprecated"] is True
|
||||
assert DeprecatedEndpointUpdateApi.post.__apidoc__["deprecated"] is True
|
||||
assert EndpointCollectionApi.post.__apidoc__.get("deprecated") is not True
|
||||
assert EndpointItemApi.delete.__apidoc__.get("deprecated") is not True
|
||||
assert EndpointItemApi.patch.__apidoc__.get("deprecated") is not True
|
||||
|
||||
def test_canonical_and_legacy_write_routes_are_registered(self):
|
||||
route_map = {
|
||||
resource.__name__: urls
|
||||
for resource, urls, _route_doc, _kwargs in console_ns.resources
|
||||
if resource.__name__
|
||||
in {
|
||||
"EndpointCollectionApi",
|
||||
"EndpointItemApi",
|
||||
"DeprecatedEndpointCreateApi",
|
||||
"DeprecatedEndpointDeleteApi",
|
||||
"DeprecatedEndpointUpdateApi",
|
||||
}
|
||||
}
|
||||
|
||||
assert route_map["EndpointCollectionApi"] == ("/workspaces/current/endpoints",)
|
||||
assert route_map["EndpointItemApi"] == ("/workspaces/current/endpoints/<string:id>",)
|
||||
assert route_map["DeprecatedEndpointCreateApi"] == ("/workspaces/current/endpoints/create",)
|
||||
assert route_map["DeprecatedEndpointDeleteApi"] == ("/workspaces/current/endpoints/delete",)
|
||||
assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointEnableApi:
|
||||
def test_enable_success(self, app):
|
||||
|
||||
@ -0,0 +1,707 @@
|
||||
"""Dedicated tests for HITL behavior exposed through the Service API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.app_generate_service as ags_module
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
HumanInputRequiredResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from graphon.nodes.human_input.enums import FormInputType
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_event_snapshot_service import _build_snapshot_events
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return "dummy-request-id"
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
|
||||
def enter(self, request_id: str | None = None) -> str:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
|
||||
data = AdvancedChatPausedBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
workflow_run_id="run-1",
|
||||
answer="partial",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[
|
||||
{
|
||||
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"expiration_time": 100,
|
||||
}
|
||||
],
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
)
|
||||
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
|
||||
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=WorkflowAppPausedBlockingResponse.Data(
|
||||
id="r1",
|
||||
workflow_id="wf-1",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=0.5,
|
||||
total_tokens=0,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=None,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FakePauseEntity(WorkflowPauseEntity):
|
||||
pause_id: str
|
||||
workflow_run_id: str
|
||||
paused_at_value: datetime
|
||||
pause_reasons: Sequence[HumanInputRequired]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.pause_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self.workflow_run_id
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
raise AssertionError("state is not required for snapshot tests")
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self.paused_at_value
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
|
||||
return self.pause_reasons
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"input": "value"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
|
||||
created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
title="Human Input",
|
||||
index=1,
|
||||
status=status.value,
|
||||
elapsed_time=0.5,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=None,
|
||||
loop_id=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.register_paused_node("node-1")
|
||||
runtime_state.outputs = {"result": "value"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
class TestHitlServiceApi:
|
||||
# Service API event-stream continuation
|
||||
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=[],
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
|
||||
method="GET",
|
||||
):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once_with(
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
session_maker=ANY,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=False,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
|
||||
|
||||
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
|
||||
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
|
||||
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
|
||||
|
||||
generator_instance = MagicMock()
|
||||
generator_instance.generate.return_value = {"result": "advanced-blocking"}
|
||||
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
|
||||
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"workflow_id": None, "query": "hi", "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "advanced-blocking"}
|
||||
call_kwargs = generator_instance.generate.call_args.kwargs
|
||||
assert call_kwargs["streaming"] is False
|
||||
assert call_kwargs["pause_state_config"] is not None
|
||||
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
|
||||
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
|
||||
|
||||
# Blocking payload contract
|
||||
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
|
||||
_build_advanced_chat_paused_blocking_response()
|
||||
)
|
||||
|
||||
assert response["event"] == "workflow_paused"
|
||||
assert response["workflow_run_id"] == "run-1"
|
||||
assert response["answer"] == "partial"
|
||||
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
assert response["data"]["reasons"][0]["expiration_time"] == 100
|
||||
assert "human_input_forms" not in response["data"]
|
||||
|
||||
def test_workflow_blocking_pause_payload_contract(self) -> None:
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
|
||||
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
|
||||
_build_workflow_paused_blocking_response()
|
||||
)
|
||||
|
||||
assert response["workflow_run_id"] == "r1"
|
||||
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
|
||||
assert response["data"]["paused_nodes"] == ["node-1"]
|
||||
assert response["data"]["reasons"] == [
|
||||
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
|
||||
]
|
||||
assert "human_input_forms" not in response["data"]
|
||||
|
||||
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
|
||||
from core.app.app_config.entities import AppAdditionalFeatures
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from models.enums import MessageStatus
|
||||
from models.model import EndUser
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
|
||||
message=SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
),
|
||||
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_content="Need approval",
|
||||
inputs=[],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
expiration_time=123,
|
||||
),
|
||||
)
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run-id",
|
||||
paused_nodes=["node-1"],
|
||||
outputs={},
|
||||
reasons=[
|
||||
{
|
||||
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"expiration_time": 123,
|
||||
},
|
||||
],
|
||||
status="paused",
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, AdvancedChatPausedBlockingResponse)
|
||||
assert response.data.answer == "partial answer"
|
||||
assert response.data.workflow_run_id == "run-id"
|
||||
assert response.data.reasons[0]["form_id"] == "form-1"
|
||||
assert response.data.reasons[0]["expiration_time"] == 123
|
||||
|
||||
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
user=SimpleNamespace(id="user", session_id="session"),
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="content",
|
||||
expiration_time=1,
|
||||
),
|
||||
)
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, WorkflowAppPausedBlockingResponse)
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
|
||||
|
||||
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
converter = _build_service_api_pause_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeSession:
|
||||
def execute(self, _stmt):
|
||||
return [("form-1", expiration_time, '{"display_in_ui": true}')]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
|
||||
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
workflow_response_converter,
|
||||
"load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "token"},
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="Rendered",
|
||||
inputs=[
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
|
||||
],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
outputs={"answer": "value"},
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
|
||||
responses = converter.workflow_pause_to_stream_response(
|
||||
event=queue_event,
|
||||
task_id="task",
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {}
|
||||
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
assert pause_resp.data.reasons[0]["form_token"] == "token"
|
||||
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
|
||||
|
||||
assert isinstance(responses[0], HumanInputRequiredResponse)
|
||||
hi_resp = responses[0]
|
||||
assert hi_resp.data.form_id == "form-1"
|
||||
assert hi_resp.data.node_id == "node-id"
|
||||
assert hi_resp.data.node_title == "Human Step"
|
||||
assert hi_resp.data.inputs[0].output_variable_name == "field"
|
||||
assert hi_resp.data.actions[0].id == "approve"
|
||||
assert hi_resp.data.display_in_ui is True
|
||||
assert hi_resp.data.form_token == "token"
|
||||
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
|
||||
|
||||
# Snapshot payload contract
|
||||
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
|
||||
resumption_context = _build_resumption_context("task-ctx")
|
||||
monkeypatch.setattr(
|
||||
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
|
||||
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
|
||||
)
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def session_maker() -> _SessionContext:
|
||||
return _SessionContext(
|
||||
SimpleNamespace(
|
||||
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
|
||||
)
|
||||
)
|
||||
|
||||
pause_entity = _FakePauseEntity(
|
||||
pause_id="pause-1",
|
||||
workflow_run_id="run-1",
|
||||
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
pause_reasons=[
|
||||
HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_token="wtok",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-ctx",
|
||||
message_context=None,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"human_input_required",
|
||||
"workflow_paused",
|
||||
]
|
||||
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
|
||||
assert events[3]["data"]["form_token"] == "wtok"
|
||||
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
|
||||
pause_data = events[-1]["data"]
|
||||
assert pause_data["paused_nodes"] == ["node-1"]
|
||||
assert pause_data["outputs"] == {"result": "value"}
|
||||
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
|
||||
assert pause_data["reasons"][0]["form_token"] == "wtok"
|
||||
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
|
||||
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
|
||||
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
|
||||
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
|
||||
assert pause_data["total_tokens"] == workflow_run.total_tokens
|
||||
assert pause_data["total_steps"] == workflow_run.total_steps
|
||||
@ -0,0 +1,184 @@
|
||||
"""Unit tests for Service API human input form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
|
||||
from models.human_input import RecipientType
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class TestWorkflowHumanInputFormApi:
|
||||
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
definition = SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"rendered_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
}
|
||||
)
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
get_definition=lambda: definition,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
response = handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload == {
|
||||
"form_content": "Rendered form content",
|
||||
"inputs": [{"output_variable_name": "name"}],
|
||||
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
|
||||
"user_actions": [{"id": "approve", "title": "Approve"}],
|
||||
"expiration_time": int(form.expiration_time.timestamp()),
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
|
||||
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="another-app",
|
||||
tenant_id="tenant-1",
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_get_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, form_token="token-1")
|
||||
|
||||
service_mock.ensure_form_active.assert_not_called()
|
||||
|
||||
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
assert response == {}
|
||||
assert status == 200
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"name": "Alice"},
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"recipient_type",
|
||||
[
|
||||
RecipientType.CONSOLE,
|
||||
RecipientType.BACKSTAGE,
|
||||
RecipientType.EMAIL_MEMBER,
|
||||
RecipientType.EMAIL_EXTERNAL,
|
||||
],
|
||||
)
|
||||
def test_post_rejects_non_service_api_recipient_types(
|
||||
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
|
||||
) -> None:
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
|
||||
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
|
||||
|
||||
service_mock.submit_form_by_token.assert_not_called()
|
||||
@ -0,0 +1,166 @@
|
||||
"""Unit tests for Service API workflow event stream endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
|
||||
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
|
||||
return workflow_events_module
|
||||
|
||||
|
||||
class TestWorkflowEventsApi:
|
||||
def test_wrong_app_mode(self, app) -> None:
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=None)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="another-user",
|
||||
finished_at=None,
|
||||
)
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
monkeypatch.setattr(
|
||||
workflow_events_module.WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: SimpleNamespace(
|
||||
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
|
||||
event=SimpleNamespace(value="workflow_finished"),
|
||||
),
|
||||
)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
body = response.get_data(as_text=True).strip()
|
||||
assert body.startswith("data: ")
|
||||
payload = json.loads(body[len("data: ") :])
|
||||
assert payload["task_id"] == "run-1"
|
||||
assert payload["event"] == "workflow_finished"
|
||||
|
||||
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
msg_generator.retrieve_events.return_value = ["raw-event"]
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: streamed\n\n"
|
||||
msg_generator.retrieve_events.assert_called_once_with(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
terminal_events=None,
|
||||
)
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
|
||||
|
||||
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="end-user-1",
|
||||
finished_at=None,
|
||||
)
|
||||
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
msg_generator = Mock()
|
||||
workflow_generator = Mock()
|
||||
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
|
||||
snapshot_builder = Mock(return_value=["snapshot-events"])
|
||||
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
|
||||
|
||||
assert response.get_data(as_text=True) == "data: snapshot\n\n"
|
||||
msg_generator.retrieve_events.assert_not_called()
|
||||
snapshot_builder.assert_called_once()
|
||||
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
|
||||
@ -22,6 +22,8 @@ import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.service_api.dataset.document import (
|
||||
DeprecatedDocumentAddByTextApi,
|
||||
DeprecatedDocumentUpdateByTextApi,
|
||||
DocumentAddByFileApi,
|
||||
DocumentAddByTextApi,
|
||||
DocumentApi,
|
||||
@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi:
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/document/create_by_text",
|
||||
f"/datasets/{mock_dataset.id}/document/create-by-text",
|
||||
method="POST",
|
||||
json={
|
||||
"name": "Test Document",
|
||||
@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi:
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/document/create_by_text",
|
||||
f"/datasets/{mock_dataset.id}/document/create-by-text",
|
||||
method="POST",
|
||||
json={"name": "Test Document", "text": "Content"},
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi:
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/document/create_by_text",
|
||||
f"/datasets/{mock_dataset.id}/document/create-by-text",
|
||||
method="POST",
|
||||
json={"name": "Test Document", "text": "Content"},
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
@ -1093,6 +1095,20 @@ class TestArchivedDocumentImmutableError:
|
||||
assert error.code == 403
|
||||
|
||||
|
||||
class TestDocumentTextRouteDeprecation:
|
||||
"""Test that legacy underscore text routes stay marked deprecated."""
|
||||
|
||||
def test_create_by_text_legacy_alias_is_deprecated(self):
|
||||
"""Ensure only the legacy create-by-text alias is marked deprecated."""
|
||||
assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True
|
||||
assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True
|
||||
|
||||
def test_update_by_text_legacy_alias_is_deprecated(self):
|
||||
"""Ensure only the legacy update-by-text alias is marked deprecated."""
|
||||
assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True
|
||||
assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi,
|
||||
# DocumentUpdateByFileApi.
|
||||
@ -1162,7 +1178,7 @@ class TestDocumentUpdateByTextApiPost:
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text",
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text",
|
||||
method="POST",
|
||||
json={"name": "Updated Doc", "text": "New content"},
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
@ -1195,7 +1211,7 @@ class TestDocumentUpdateByTextApiPost:
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text",
|
||||
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text",
|
||||
method="POST",
|
||||
json={"name": "Doc", "text": "Content"},
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -10,7 +13,8 @@ from core.app.entities.task_entities import (
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
@ -28,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter:
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch):
|
||||
data = AdvancedChatPausedBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
workflow_run_id="run-1",
|
||||
answer="partial",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
paused_nodes=["node-1"],
|
||||
reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}],
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
)
|
||||
original_model_dump = type(data).model_dump
|
||||
|
||||
def _model_dump_with_future_field(self, *args, **kwargs):
|
||||
payload = original_model_dump(self, *args, **kwargs)
|
||||
payload["future_field"] = "future-value"
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field)
|
||||
blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert response["data"]["future_field"] == "future-value"
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
|
||||
@ -39,15 +39,19 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.nodes.human_input.entities import UserAction
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import MessageStatus
|
||||
@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
total_tokens=7,
|
||||
node_run_steps=3,
|
||||
)
|
||||
|
||||
def _gen():
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_content="Need approval",
|
||||
inputs=[],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
expiration_time=123,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, AdvancedChatPausedBlockingResponse)
|
||||
assert response.data.workflow_run_id == "run-id"
|
||||
assert response.data.status == "paused"
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [
|
||||
{
|
||||
"TYPE": PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"form_content": "Need approval",
|
||||
"inputs": [],
|
||||
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
|
||||
"display_in_ui": True,
|
||||
"form_token": "token-1",
|
||||
"resolved_default_values": {},
|
||||
"expiration_time": 123,
|
||||
}
|
||||
]
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
|
||||
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
|
||||
blocking_full_calls: list[WorkflowAppBlockingResponse] = []
|
||||
blocking_simple_calls: list[WorkflowAppBlockingResponse] = []
|
||||
stream_full_calls: list[Generator[AppStreamResponse, None, None]] = []
|
||||
stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = []
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
cls.blocking_full_calls = []
|
||||
cls.blocking_simple_calls = []
|
||||
cls.stream_full_calls = []
|
||||
cls.stream_simple_calls = []
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
cls.blocking_full_calls.append(blocking_response)
|
||||
return {"kind": "blocking-full", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
cls.blocking_simple_calls.append(blocking_response)
|
||||
return {"kind": "blocking-simple", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls.stream_full_calls.append(stream_response)
|
||||
yield {"kind": "stream-full"}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls.stream_simple_calls.append(stream_response)
|
||||
yield {"kind": "stream-simple"}
|
||||
|
||||
|
||||
def _build_blocking_response() -> WorkflowAppBlockingResponse:
|
||||
return WorkflowAppBlockingResponse(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="run-1",
|
||||
workflow_id="workflow-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_stream_response() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run-1",
|
||||
stream_response=PingStreamResponse(task_id="task-1"),
|
||||
)
|
||||
|
||||
|
||||
def test_convert_routes_blocking_response_by_invoke_from() -> None:
|
||||
_DummyConverter.reset()
|
||||
blocking_response = _build_blocking_response()
|
||||
|
||||
full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API)
|
||||
simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP)
|
||||
|
||||
assert full_result == {"kind": "blocking-full", "task_id": "task-1"}
|
||||
assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"}
|
||||
assert _DummyConverter.blocking_full_calls == [blocking_response]
|
||||
assert _DummyConverter.blocking_simple_calls == [blocking_response]
|
||||
|
||||
|
||||
def test_convert_routes_stream_response_by_invoke_from() -> None:
|
||||
_DummyConverter.reset()
|
||||
|
||||
full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API))
|
||||
simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP))
|
||||
|
||||
assert full_result == [{"kind": "stream-full"}]
|
||||
assert simple_result == [{"kind": "stream-simple"}]
|
||||
assert len(_DummyConverter.stream_full_calls) == 1
|
||||
assert len(_DummyConverter.stream_simple_calls) == 1
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -23,7 +24,21 @@ class TestMessageGenerator:
|
||||
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
|
||||
) as mock_stream,
|
||||
):
|
||||
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
|
||||
events = list(
|
||||
MessageGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
"run-1",
|
||||
idle_timeout=1,
|
||||
ping_interval=2,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [{"event": "ping"}]
|
||||
mock_stream.assert_called_once()
|
||||
mock_stream.assert_called_once_with(
|
||||
topic="topic",
|
||||
idle_timeout=1,
|
||||
ping_interval=2,
|
||||
on_subscribe=None,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
|
||||
@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults():
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_terminal_events_empty_values():
|
||||
assert _normalize_terminal_events([]) == set({})
|
||||
|
||||
|
||||
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
|
||||
@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
# next receive yields None -> ping interval triggers
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
|
||||
|
||||
def test_stream_topic_events_can_continue_past_pause():
|
||||
topic = FakeTopic()
|
||||
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode())
|
||||
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())
|
||||
|
||||
generator = stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=1.0,
|
||||
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value
|
||||
assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
@ -36,11 +36,12 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
@ -91,27 +92,50 @@ def _make_pipeline():
|
||||
|
||||
|
||||
class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_handles_pause(self):
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
total_tokens=5,
|
||||
node_run_steps=2,
|
||||
)
|
||||
|
||||
def _gen():
|
||||
yield WorkflowPauseStreamResponse(
|
||||
yield HumanInputRequiredResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
workflow_run_id="run-id",
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="content",
|
||||
expiration_time=1,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert isinstance(response, WorkflowAppPausedBlockingResponse)
|
||||
assert response.workflow_run_id == "run-id"
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
assert response.data.created_at == 0
|
||||
assert response.data.paused_nodes == ["node-1"]
|
||||
assert response.data.reasons == [
|
||||
{
|
||||
"TYPE": "human_input_required",
|
||||
"form_id": "form-1",
|
||||
"node_id": "node-1",
|
||||
"node_title": "Human Input",
|
||||
"form_content": "content",
|
||||
"inputs": [],
|
||||
"actions": [],
|
||||
"display_in_ui": False,
|
||||
"form_token": None,
|
||||
"resolved_default_values": {},
|
||||
"expiration_time": 1,
|
||||
}
|
||||
]
|
||||
|
||||
def test_to_blocking_response_handles_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
106
api/tests/unit_tests/core/helper/test_creators.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Tests for the Creators Platform helper module."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_creators_url(monkeypatch):
|
||||
"""Patch the module-level creators_platform_api_url for all tests."""
|
||||
monkeypatch.setattr(
|
||||
"core.helper.creators.creators_platform_api_url",
|
||||
URL("https://creators.example.com"),
|
||||
)
|
||||
|
||||
|
||||
class TestUploadDSL:
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_returns_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
result = upload_dsl(b"app: demo", "demo.yaml")
|
||||
|
||||
assert result == "abc123"
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert "anonymous-upload" in call_kwargs.args[0]
|
||||
assert call_kwargs.kwargs["timeout"] == 30
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_missing_claim_code(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"data": {}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(ValueError, match="claim_code"):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
@patch("core.helper.creators.httpx.post")
|
||||
def test_raises_on_http_error(self, mock_post):
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(),
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from core.helper.creators import upload_dsl
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
upload_dsl(b"app: demo")
|
||||
|
||||
|
||||
class TestGetRedirectUrl:
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_without_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code" not in url
|
||||
assert url.startswith("https://creators.example.com")
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_with_oauth_client_id(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
|
||||
|
||||
with patch(
|
||||
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="oauth-code-123",
|
||||
) as mock_sign:
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
mock_sign.assert_called_once_with("client-xyz", "user-1")
|
||||
assert "dsl_claim_code=claim-abc" in url
|
||||
assert "oauth_code=oauth-code-123" in url
|
||||
|
||||
@patch("core.helper.creators.dify_config")
|
||||
def test_strips_trailing_slash(self, mock_config):
|
||||
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
|
||||
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
|
||||
|
||||
from core.helper.creators import get_redirect_url
|
||||
|
||||
url = get_redirect_url("user-1", "claim-abc")
|
||||
|
||||
assert url.startswith("https://creators.example.com?")
|
||||
assert "creators.example.com/?" not in url
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user