Merge remote-tracking branch 'origin/main' into deploy/dev

# Conflicts:
#	api/core/app/entities/task_entities.py
#	api/core/llm_generator/llm_generator.py
#	api/core/llm_generator/output_parser/suggested_questions_after_answer.py
#	api/core/llm_generator/prompts.py
#	api/models/comment.py
#	api/services/workflow_event_snapshot_service.py
#	api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
#	api/tests/unit_tests/core/llm_generator/test_llm_generator.py
#	api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py
#	eslint-suppressions.json
#	web/app/components/app/app-publisher/index.tsx
#	web/app/components/datasets/settings/permission-selector/index.tsx
#	web/app/components/plugins/plugin-page/plugin-tasks/__tests__/index.spec.tsx
#	web/app/components/workflow/block-selector/blocks.tsx
#	web/app/components/workflow/block-selector/main.tsx
#	web/i18n/ar-TN/app-debug.json
#	web/i18n/de-DE/app-debug.json
#	web/i18n/en-US/workflow.json
#	web/i18n/es-ES/app-debug.json
#	web/i18n/fa-IR/app-debug.json
#	web/i18n/fr-FR/app-debug.json
#	web/i18n/hi-IN/app-debug.json
#	web/i18n/id-ID/app-debug.json
#	web/i18n/it-IT/app-debug.json
#	web/i18n/ja-JP/app-debug.json
#	web/i18n/ko-KR/app-debug.json
#	web/i18n/nl-NL/app-debug.json
#	web/i18n/pl-PL/app-debug.json
#	web/i18n/pt-BR/app-debug.json
#	web/i18n/ro-RO/app-debug.json
#	web/i18n/ru-RU/app-debug.json
#	web/i18n/sl-SI/app-debug.json
#	web/i18n/th-TH/app-debug.json
#	web/i18n/tr-TR/app-debug.json
#	web/i18n/uk-UA/app-debug.json
#	web/i18n/vi-VN/app-debug.json
#	web/i18n/zh-Hans/workflow.json
#	web/i18n/zh-Hant/app-debug.json
This commit is contained in:
Junyan Qin 2026-04-24 16:27:59 +08:00
commit d92946e241
481 changed files with 21420 additions and 8276 deletions

View File

@ -367,7 +367,7 @@ For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo
│ 3. Run: pnpm type-check
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │

View File

@ -127,7 +127,7 @@ For the current file being tested:
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
- [ ] Run `pnpm type-check`
## Common Issues to Watch

View File

@ -1,19 +0,0 @@
name: Anti-Slop PR Check
on:
pull_request_target:
types: [opened, edited, synchronize]
permissions:
pull-requests: write
contents: read
jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false
failure-add-pr-labels: "needs-revision"

View File

@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -105,7 +105,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -25,7 +25,7 @@ jobs:
- name: Check Docker Compose inputs
if: github.event_name != 'merge_group'
id: docker-compose-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
docker/generate_docker_compose
@ -35,7 +35,7 @@ jobs:
- name: Check web inputs
if: github.event_name != 'merge_group'
id: web-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@ -48,7 +48,7 @@ jobs:
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@ -58,7 +58,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@ -123,4 +123,4 @@ jobs:
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- if: github.event_name != 'merge_group'
uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4

View File

@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -25,7 +25,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: false
python-version: "3.12"
@ -73,7 +73,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@ -93,7 +93,7 @@ jobs:
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
@ -122,7 +122,7 @@ jobs:
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
@ -140,7 +140,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
**.sh

View File

@ -30,7 +30,7 @@ jobs:
persist-credentials: false
- name: Use Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version: 22
cache: ''

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -65,7 +65,7 @@ jobs:
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -62,7 +62,7 @@ jobs:
# tiflash
- name: Set up Vector Stores for Smoke Coverage
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -28,7 +28,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"

4
.gitignore vendored
View File

@ -236,6 +236,10 @@ scripts/stress-test/reports/
.playwright-mcp/
.serena/
# vitest browser mode attachments (failure screenshots, traces, etc.)
.vitest-attachments/
**/__screenshots__/
# settings
*.local.json
*.local.md

View File

@ -30,7 +30,7 @@ The codebase is split into:
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
## General Practices

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a
uv run ruff format ./ # Format code
uv run basedpyright . # Type checking
```
## Generate TS stub
```
uv run dev/generate_swagger_specs.py --output-dir openapi
```
use https://jsontotable.org/openapi-to-typescript to convert to typescript

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1405,6 +1426,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
action: str

View File

@ -728,6 +728,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
# Default to the initial phase; any legacy/unexpected client input is
# coerced back to `old_email` so we never trust the caller to declare
# later phases without a verified predecessor token.
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
# The token used to request a new-email code must come from the
# old-email verification step. This prevents the bypass described
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
phase=send_phase,
)
return {"result": "success", "data": token}
@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Only advance tokens that were minted by the matching send-code step;
# refuse tokens that have already progressed or lack a phase marker so
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
# is strictly enforced.
phase_transitions = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
# Refresh token data by generating a new token that carries the
# upgraded phase so later steps can check it.
_, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
# Only tokens that completed both verification phases may be used to
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
# the initial send-code step could be replayed directly here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
# Bind the new email to the token that was mailed and verified, so a
# verified token cannot be reused with a different `new_email` value.
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
# Revoke only after all checks pass so failed attempts don't burn a
# legitimately verified token.
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@ -1,3 +1,11 @@
"""Console workspace endpoint controllers.
This module exposes workspace-scoped plugin endpoint management APIs. The
canonical write routes follow resource-oriented paths, while the historical
verb-based aliases stay available as deprecated resources so OpenAPI metadata
marks only the legacy paths as deprecated.
"""
from typing import Any
from flask import request
@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(EndpointIdPayload):
class EndpointUpdatePayload(BaseModel):
settings: dict[str, Any]
name: str = Field(min_length=1)
class LegacyEndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
@ -76,6 +89,7 @@ register_schema_models(
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
LegacyEndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
@ -88,8 +102,60 @@ register_schema_models(
)
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
def _create_endpoint() -> dict[str, bool]:
"""Create a plugin endpoint for the current workspace."""
user, tenant_id = current_account_with_tenant()
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Update a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=args.name,
settings=args.settings,
)
}
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Delete a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
)
}
@console_ns.route("/workspaces/current/endpoints")
class EndpointCollectionApi(Resource):
"""Canonical collection resource for endpoint creation."""
@console_ns.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint")
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@ -104,22 +170,33 @@ class EndpointCreateApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
return _create_endpoint()
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e
@console_ns.route("/workspaces/current/endpoints/create")
class DeprecatedEndpointCreateApi(Resource):
"""Deprecated verb-based alias for endpoint creation."""
@console_ns.doc("create_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
)
)
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.models[EndpointCreateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
@console_ns.route("/workspaces/current/endpoints/list")
@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource):
)
@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource):
@console_ns.route("/workspaces/current/endpoints/<string:id>")
class EndpointItemApi(Resource):
"""Canonical item resource for endpoint updates and deletion."""
@console_ns.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint")
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.models[EndpointDeleteResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, id: str):
return _delete_endpoint(endpoint_id=id)
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.models[EndpointUpdateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def patch(self, id: str):
return _update_endpoint(endpoint_id=id)
@console_ns.route("/workspaces/current/endpoints/delete")
class DeprecatedEndpointDeleteApi(Resource):
"""Deprecated verb-based alias for endpoint deletion."""
@console_ns.doc("delete_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for deleting a plugin endpoint. "
"Use DELETE /workspaces/current/endpoints/{id} instead."
)
)
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.response(
200,
@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}
return _delete_endpoint(endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
class DeprecatedEndpointUpdateApi(Resource):
"""Deprecated verb-based alias for endpoint updates."""
@console_ns.doc("update_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for updating a plugin endpoint. "
"Use PATCH /workspaces/current/endpoints/{id} instead."
)
)
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
@console_ns.response(
200,
"Endpoint updated successfully",
@ -233,19 +357,8 @@ class EndpointUpdateApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=args.endpoint_id,
name=args.name,
settings=args.settings,
)
}
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
return _update_endpoint(endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/enable")

View File

@ -23,9 +23,11 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@ -50,6 +52,7 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@ -58,6 +61,7 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
# Keep app-token callers scoped to the public web-form surface; internal HITL
# routes must continue to flow through console-only authentication.
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token
def get(self, app_model: App, form_token: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form submitted successfully",
400: "Bad request - invalid submission data",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, form_token: str):
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
recipient_type = form.recipient_type
if recipient_type is None:
logger.warning("Recipient type is None for form, form_id=%s", form.id)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=end_user.id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(
params={
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": (
"Whether to replay from persisted state snapshot, "
'specify `"true"` to include a status snapshot of executed nodes'
),
"continue_on_pause": (
"Whether to keep the stream open across workflow_paused events,"
'specify `"true"` to keep the stream open for `workflow_paused` events.'
),
}
)
@service_api_ns.doc(
responses={
200: "SSE event stream",
401: "Unauthorized - invalid API token",
404: "Workflow run not found",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise NotWorkflowAppError()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFound("Workflow run not found")
if workflow_run.created_by != end_user.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

View File

@ -1,4 +1,12 @@
"""Service API endpoints for dataset document management.
The canonical Service API paths use hyphenated route segments. Legacy underscore
aliases remain registered for backward compatibility, but they must stay marked
deprecated in generated API docs so clients migrate toward the canonical paths.
"""
import json
from collections.abc import Mapping
from contextlib import ExitStack
from typing import Self
from uuid import UUID
@ -117,12 +125,137 @@ register_schema_models(
)
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/document/create_by_text",
"/datasets/<uuid:dataset_id>/document/create-by-text",
)
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
"""Create a document from text for both canonical and legacy routes."""
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id_str = str(dataset_id)
tenant_id_str = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id_str,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
if not current_user:
raise ValueError("current_user is required")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
"""Update a document from text for both canonical and legacy routes."""
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
)
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args.get("text"):
text = args.get("text")
name = args.get("name")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
"""Resource for the canonical text document creation route."""
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@ -138,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
def post(self, tenant_id: str, dataset_id: UUID):
"""Create document by text."""
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
"""Deprecated resource alias for text document creation."""
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for creating a new document by providing text content. "
"Use /datasets/{dataset_id}/document/create-by-text instead."
)
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(
responses={
200: "Document created successfully",
401: "Unauthorized - invalid API token",
400: "Bad request - invalid parameters",
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
if not current_user:
raise ValueError("current_user is required")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID):
"""Create document by text through the deprecated underscore alias."""
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
"""Resource for the canonical text document update route."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@ -229,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
"""Deprecated resource alias for text document updates."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for updating an existing document by providing text content. "
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
)
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args.get("text"):
text = args.get("text")
name = args.get("name")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
401: "Unauthorized - invalid API token",
404: "Document not found",
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text through the deprecated underscore alias."""
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.route(

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@ -39,7 +39,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@ -656,7 +660,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AdvancedChatPausedBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = {
"event": "message",
"event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
"""
Process generate task pipeline.
:return:
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return AdvancedChatPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=paused_nodes,
reasons=reasons,
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import JsonValue
from graphon.model_runtime.errors.invoke import InvokeError
@ -12,8 +14,10 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
@classmethod
def convert(
@ -21,7 +25,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -30,7 +34,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -40,12 +44,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@ -107,13 +111,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -127,7 +131,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data: dict[str, JsonValue] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -65,6 +65,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = []

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
response: dict[str, Any] = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -29,7 +29,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@ -29,7 +29,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -633,7 +637,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -45,12 +45,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
"""
Process generate task pipeline.
:return:
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
created_at = int(runtime_state.start_at)
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=human_input_responses[-1].workflow_run_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
created_at=created_at,
finished_at=None,
paused_nodes=paused_nodes,
reasons=reasons,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -1,15 +1,16 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, Literal
from graphon.entities import WorkflowStartReason
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
"""
Public pause-reason payload used by blocking responses when only
``human_input_required`` events are available.
"""
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@classmethod
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
return cls(
form_id=data.form_id,
node_id=data.node_id,
node_title=data.node_title,
form_content=data.form_content,
inputs=data.inputs,
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
"""
ChatbotAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
message_id: str
workflow_run_id: str
answer: str
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
status: WorkflowExecutionStatus
elapsed_time: float
total_tokens: int
total_steps: int
data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
"""
WorkflowAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
workflow_id: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_at: int
finished_at: int | None
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Iterator
from collections.abc import Generator # Changed from Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
token = _current_file_access_scope.set(scope)
try:
yield

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
@ -15,8 +16,21 @@ from core.provider_manager import ProviderManager
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@ -31,8 +45,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@ -47,6 +65,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@ -66,7 +85,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@ -85,7 +105,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt

View File

@ -107,6 +107,7 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@ -70,12 +70,32 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
Configuration assembly is cached per manager instance so call chains that
share one request-scoped manager can reuse the same provider graph instead
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
when a long-lived manager needs to observe writes performed within the same
instance scope.
"""
decoding_rsa_key: Any | None
decoding_cipher_rsa: Any | None
_model_runtime: ModelRuntime
_configurations_cache: dict[str, ProviderConfigurations]
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
self._configurations_cache = {}
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
"""Drop assembled provider configurations cached on this manager instance."""
if tenant_id is None:
self._configurations_cache.clear()
return
self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@ -114,6 +134,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
cached_configurations = self._configurations_cache.get(tenant_id)
if cached_configurations is not None:
return cached_configurations
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@ -273,6 +297,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
self._configurations_cache[tenant_id] = provider_configurations
# Return the encapsulated object
return provider_configurations

View File

@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
if keyword_data_source_type == "database":
if dataset_keyword_table is None:
return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
@ -154,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@ -1,4 +1,5 @@
import re
from collections.abc import Callable
from operator import itemgetter
from typing import cast
@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = kwargs.pop("topK", top_k)
top_k = cast(int | None, kwargs.pop("topK", top_k))
if top_k is None:
top_k = 20
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cut(sentence))
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
else:
tokens = re.findall(r"\w+", sentence)
@ -106,9 +109,9 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
topK=max_keywords_per_chunk or 10,
)
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))

View File

@ -158,7 +158,7 @@ class RetrievalService:
)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
for _ in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()

View File

@ -94,6 +94,7 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@ -104,6 +105,7 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":

View File

@ -29,10 +29,10 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result: LLMResult = model_instance.invoke_llm(
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import codecs
import re
from collections.abc import Collection
from collections.abc import Set as AbstractSet
from typing import Any, Literal
from core.model_manager import ModelInstance
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
_ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)

View File

@ -4,7 +4,8 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Set as AbstractSet
from dataclasses import dataclass
from typing import Any, Literal
@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -105,7 +105,7 @@ class Article:
def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",

View File

@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids)
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
tokens_by_form_id: dict[str, tuple[int, str]] = {}
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
if priority is None or not recipient.access_token:
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token)
current = tokens_by_form_id.get(recipient.form_id)
if current is None or candidate[0] < current[0]:
tokens_by_form_id[recipient.form_id] = candidate
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token
return get_preferred_form_token(recipients)

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
recipient_type: RecipientType | None,
surface: HumanInputSurface,
) -> bool:
if recipient_type is None:
return False
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
chosen_token: str | None = None
chosen_priority: int | None = None
for recipient_type, token in recipients:
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
if priority is None or not token:
continue
if chosen_priority is None or priority < chosen_priority:
chosen_priority = priority
chosen_token = token
return chosen_token
def enrich_human_input_pause_reasons(
reasons: Sequence[Mapping[str, Any]],
*,
form_tokens_by_form_id: Mapping[str, str],
expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
enriched: list[dict[str, Any]] = []
for reason in reasons:
updated = dict(reason)
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is not None:
updated["expiration_time"] = expiration_time
enriched.append(updated)
return enriched

View File

@ -0,0 +1,172 @@
"""Generate Flask-RESTX Swagger 2.0 specs without booting the full backend.
This helper intentionally avoids `app_factory.create_app()`. The normal backend
startup eagerly initializes database, Redis, Celery, and storage extensions,
which is unnecessary when the goal is only to serialize the Flask-RESTX
`/swagger.json` documents.
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from flask import Flask
from flask_restx.swagger import Swagger
logger = logging.getLogger(__name__)
API_ROOT = Path(__file__).resolve().parents[1]
if str(API_ROOT) not in sys.path:
sys.path.insert(0, str(API_ROOT))
@dataclass(frozen=True)
class SpecTarget:
route: str
filename: str
SPEC_TARGETS: tuple[SpecTarget, ...] = (
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
)
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
def _apply_runtime_defaults() -> None:
"""Force the small config surface required for Swagger generation."""
os.environ.setdefault("SECRET_KEY", "spec-export")
os.environ.setdefault("STORAGE_TYPE", "local")
os.environ.setdefault("STORAGE_LOCAL_PATH", "/tmp/dify-storage")
os.environ.setdefault("SWAGGER_UI_ENABLED", "true")
from configs import dify_config
dify_config.SECRET_KEY = os.environ["SECRET_KEY"]
dify_config.STORAGE_TYPE = "local"
dify_config.STORAGE_LOCAL_PATH = os.environ["STORAGE_LOCAL_PATH"]
dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true"
def _patch_swagger_for_inline_nested_dicts() -> None:
"""Teach Flask-RESTX Swagger generation to tolerate inline nested field maps.
Some existing controllers use `fields.Nested({...})` with a raw field mapping
instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous
dicts during schema registration, so this helper upgrades them into temporary
named models at export time.
"""
if getattr(Swagger, "_dify_inline_nested_dict_patch", False):
return
def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object:
anonymous_models = getattr(self, "_anonymous_inline_models", None)
if anonymous_models is None:
anonymous_models = {}
self._anonymous_inline_models = anonymous_models
anonymous_name = anonymous_models.get(id(nested_fields))
if anonymous_name is None:
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
anonymous_models[id(nested_fields)] = anonymous_name
self.api.model(anonymous_name, nested_fields)
return self.api.models[anonymous_name]
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
if isinstance(model, dict):
model = get_or_create_inline_model(self, model)
return _ORIGINAL_REGISTER_MODEL(self, model)
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
nested = getattr(field, "nested", None)
if isinstance(nested, dict):
field.model = get_or_create_inline_model(self, nested) # type: ignore
_ORIGINAL_REGISTER_FIELD(self, field)
Swagger.register_model = register_model_with_inline_dict_support
Swagger.register_field = register_field_with_inline_dict_support
Swagger._dify_inline_nested_dict_patch = True
def create_spec_app() -> Flask:
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
_apply_runtime_defaults()
_patch_swagger_for_inline_nested_dicts()
app = Flask(__name__)
from controllers.console import bp as console_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
app.register_blueprint(console_bp)
app.register_blueprint(web_bp)
app.register_blueprint(service_api_bp)
return app
def generate_specs(output_dir: Path) -> list[Path]:
"""Write all Swagger specs to `output_dir` and return the written paths."""
output_dir.mkdir(parents=True, exist_ok=True)
app = create_spec_app()
client = app.test_client()
written_paths: list[Path] = []
for target in SPEC_TARGETS:
response = client.get(target.route)
if response.status_code != 200:
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
payload = response.get_json()
if not isinstance(payload, dict):
raise RuntimeError(f"unexpected response payload for {target.route}")
output_path = output_dir / target.filename
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
written_paths.append(output_path)
return written_paths
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"-o",
"--output-dir",
type=Path,
default=Path("openapi"),
help="Directory where the Swagger JSON files will be written.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
written_paths = generate_specs(args.output_dir)
for path in written_paths:
logger.debug(path)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -1,5 +1,5 @@
import contextvars
from collections.abc import Iterator
from collections.abc import Generator # Changed from Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
@ -13,7 +13,7 @@ if TYPE_CHECKING:
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
) -> Iterator[None]:
) -> Generator[None, None, None]: # Changed from Iterator[None]
"""
A context manager that handles:
1. flask-login's UserProxy copy

View File

@ -42,8 +42,7 @@ class WorkflowComment(Base):
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
@ -152,8 +151,7 @@ class WorkflowCommentReply(Base):
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
@ -200,8 +198,7 @@ class WorkflowCommentMention(Base):
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)

View File

@ -1,6 +1,6 @@
import json
import uuid
from collections.abc import Iterator
from collections.abc import Generator # Added Generator
from contextlib import contextmanager
from typing import Any
@ -75,7 +75,7 @@ class AnalyticdbVectorBySql:
)
@contextmanager
def _get_cursor(self) -> Iterator[Any]:
def _get_cursor(self) -> Generator[Any, None, None]: # Changed from Iterator[Any]
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()

View File

@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
from packaging import version
from pydantic import BaseModel, model_validator
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False
try:
milvus_version = self._client.get_server_version()
milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True

View File

@ -3,7 +3,7 @@ import json
import logging
import re
import uuid
from typing import Any
from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore
import numpy
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel):
user: str
password: str
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 5,
"increment": 1,
}
pool_params = _OraclePoolParams(
user=config.user,
password=config.password,
dsn=config.dsn,
min=1,
max=5,
increment=1,
)
if config.is_autonomous:
pool_params.update(
{
"config_dir": config.config_dir,
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
pool_params["config_dir"] = config.config_dir
pool_params["wallet_location"] = config.wallet_location
pool_params["wallet_password"] = config.wallet_password
return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):

View File

@ -6,7 +6,7 @@ requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
"boto3>=1.42.88",
"boto3>=1.42.91",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask-cors>=6.0.2",
@ -30,7 +30,7 @@ dependencies = [
"flask-migrate>=4.1.0,<5.0.0",
"flask-orjson>=2.0.0,<3.0.0",
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.147.0,<2.0.0",
"google-cloud-aiplatform>=1.148.1,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
@ -46,7 +46,7 @@ dependencies = [
"fastopenapi[flask]~=0.7.0",
"graphon~=0.2.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
"json-repair~=0.59.4",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@ -114,10 +114,10 @@ override-dependencies = [
dev = [
"coverage>=7.13.4",
"dotenv-linter>=0.7.0",
"faker>=20.1.0",
"faker>=40.15.0",
"lxml-stubs>=0.5.1",
"basedpyright>=1.39.0",
"ruff>=0.15.10",
"basedpyright>=1.39.3",
"ruff>=0.15.11",
"pytest>=9.0.3",
"pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0",
@ -157,14 +157,14 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.88",
"boto3-stubs>=1.42.92",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.151.12",
"hypothesis>=6.152.1",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
"pandas-stubs>=3.0.0",
"scipy-stubs>=1.15.3.0",
"scipy-stubs>=1.17.1.4",
"types-python-http-client>=3.3.7.20260408",
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
@ -173,8 +173,8 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.61.1",
"xinference-client>=2.4.0",
"pyrefly>=0.62.0",
"xinference-client>=2.5.0",
]
############################################################
@ -183,13 +183,13 @@ dev = [
############################################################
storage = [
"azure-storage-blob>=12.28.0",
"bce-python-sdk>=0.9.69",
"bce-python-sdk>=0.9.70",
"cos-python-sdk-v5>=1.9.41",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
"supabase>=2.18.1",
"supabase>=2.28.3",
"tos>=2.9.0",
]
@ -272,7 +272,7 @@ vdb-vastbase = ["dify-vdb-vastbase"]
vdb-vikingdb = ["dify-vdb-vikingdb"]
vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.4.0"]
vdb-xinference = ["xinference-client>=2.5.0"]
trace-all = [
"dify-trace-aliyun",

View File

@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm
from models.human_input import HumanInputForm, HumanInputFormRecipient
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired:
form_content = ""
inputs = []
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
return HumanInputRequired(
reason = HumanInputRequired(
form_id=form_id,
form_content=form_content,
inputs=inputs,
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
node_title=node_title,
resolved_default_values=resolved_default_values,
)
return reason
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
pause_reasons.append(
_build_human_input_required_reason(
reason,
form_model,
recipients_by_form_id.get(reason.form_id, ()),
)
)
else:
pause_reasons.append(reason.to_entity())
return pause_reasons

View File

@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
# Phase-bound token metadata for the change-email flow. Tokens carry the
# current phase so that downstream endpoints can enforce proper progression
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
CHANGE_EMAIL_PHASE_OLD = "old_email"
CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified"
CHANGE_EMAIL_PHASE_NEW = "new_email"
CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified"
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@ -576,13 +584,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
code, token = cls.generate_change_email_token(
account_email,
account,
old_email=old_email,
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
)
send_change_mail_task.delay(
language=language,

View File

@ -164,6 +164,7 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
@ -185,6 +186,10 @@ class AppGenerateService:
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
@ -196,6 +201,7 @@ class AppGenerateService:
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,

View File

@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")

View File

@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features
@classmethod

View File

@ -2,7 +2,7 @@ import base64
import hashlib
import os
import uuid
from collections.abc import Iterator, Sequence
from collections.abc import Generator, Sequence # Changed Iterator to Generator
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal
@ -324,7 +324,7 @@ class FileService:
def build_upload_files_zip_tempfile(
*,
upload_files: Sequence[UploadFile],
) -> Iterator[str]:
) -> Generator[str, None, None]: # Changed from Iterator[str]
"""
Build a ZIP from `UploadFile`s and yield a tempfile path.

View File

@ -1,10 +1,10 @@
import json
import logging
import time
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
from core.app.app_config.entities import ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -36,6 +36,10 @@ default_retrieval_model = {
}
class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
metadata_filtering_conditions: dict[str, Any]
class HitTestingService:
@classmethod
def retrieve(
@ -51,17 +55,18 @@ class HitTestingService:
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
if not retrieval_model:
retrieval_model = dataset.retrieval_model or default_retrieval_model
assert isinstance(retrieval_model, dict)
resolved_retrieval_model = cast(
HitTestingRetrievalModelDict,
retrieval_model or dataset.retrieval_model or default_retrieval_model,
)
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions and query:
metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions_raw and query:
dataset_retrieval = DatasetRetrieval()
from core.rag.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
@ -78,19 +83,21 @@ class HitTestingService:
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
retrieval_method=RetrievalMethod(
resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
top_k=resolved_retrieval_model.get("top_k", 4),
score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
if resolved_retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
reranking_model=resolved_retrieval_model.get("reranking_model", None)
if resolved_retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
weights=resolved_retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -18,6 +18,7 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
HumanInputRequiredResponse,
MessageReplaceStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@ -26,6 +27,10 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import HumanInputForm
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
@ -59,8 +64,10 @@ def build_workflow_event_stream(
tenant_id: str,
app_id: str,
session_maker: sessionmaker[Session],
human_input_surface: HumanInputSurface | None = None,
idle_timeout: float = 300,
ping_interval: float = 10.0,
close_on_pause: bool = True,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@ -115,13 +122,15 @@ def build_workflow_event_stream(
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
while True:
@ -146,7 +155,7 @@ def build_workflow_event_stream(
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
if _is_terminal_event(event, close_on_pause=close_on_pause):
return
finally:
buffer_state.stop_event.set()
@ -207,6 +216,8 @@ def _build_snapshot_events(
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None = None,
human_input_surface: HumanInputSurface | None = None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
@ -241,12 +252,24 @@ def _build_snapshot_events(
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
for human_input_event in _build_human_input_required_events(
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
session_maker=session_maker,
human_input_surface=human_input_surface,
):
_apply_message_context(human_input_event, message_context)
events.append(human_input_event)
pause_event = _build_pause_event(
workflow_run=workflow_run,
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
human_input_surface=human_input_surface,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
@ -314,6 +337,97 @@ def _build_node_started_event(
return response.to_ignore_detail_dict()
def _build_human_input_required_events(
*,
workflow_run_id: str,
task_id: str,
pause_entity: WorkflowPauseEntity,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None,
) -> list[dict[str, Any]]:
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
expiration_times_by_form_id: dict[str, int] = {}
display_in_ui_by_form_id: dict[str, bool] = {}
form_tokens_by_form_id: dict[str, str] = {}
if human_input_form_ids and session_maker is not None:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with session_maker() as session:
for form_id, expiration_time, form_definition in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
try:
definition_payload = json.loads(form_definition) if form_definition else {}
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
events: list[dict[str, Any]] = []
for reason in reasons:
if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED:
continue
form_id_raw = reason.get("form_id")
node_id_raw = reason.get("node_id")
node_title_raw = reason.get("node_title")
form_content_raw = reason.get("form_content")
if not isinstance(form_id_raw, str):
continue
if not isinstance(node_id_raw, str):
continue
if not isinstance(node_title_raw, str):
continue
if not isinstance(form_content_raw, str):
continue
form_id = form_id_raw
node_id = node_id_raw
node_title = node_title_raw
form_content = form_content_raw
inputs = reason.get("inputs")
actions = reason.get("actions")
resolved_default_values = reason.get("resolved_default_values")
expiration_time = expiration_times_by_form_id.get(form_id)
if expiration_time is None:
continue
response = HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=HumanInputRequiredResponse.Data(
form_id=form_id,
node_id=node_id,
node_title=node_title,
form_content=form_content,
inputs=inputs if isinstance(inputs, list) else [],
actions=actions if isinstance(actions, list) else [],
display_in_ui=display_in_ui_by_form_id.get(form_id, False),
form_token=form_tokens_by_form_id.get(form_id),
resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}),
expiration_time=expiration_time,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
events.append(payload)
return events
def _build_node_finished_event(
*,
workflow_run_id: str,
@ -356,6 +470,8 @@ def _build_pause_event(
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
session_maker: sessionmaker[Session] | None,
human_input_surface: HumanInputSurface | None = None,
) -> dict[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
@ -365,6 +481,36 @@ def _build_pause_event(
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
human_input_form_ids = [
form_id
for reason in reasons
if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
for form_id in [reason.get("form_id")]
if isinstance(form_id, str)
]
form_tokens_by_form_id: dict[str, str] = {}
expiration_times_by_form_id: dict[str, int] = {}
if human_input_form_ids and session_maker is not None:
with session_maker() as session:
form_tokens_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=human_input_surface,
)
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
for row in session.execute(stmt):
form_id, expiration_time, *_rest = row
expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp())
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
reasons = enrich_human_input_pause_reasons(
reasons,
form_tokens_by_form_id=form_tokens_by_form_id,
expiration_times_by_form_id=expiration_times_by_form_id,
)
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
return event
def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
def _is_terminal_event(
event: Mapping[str, Any] | str,
close_on_pause: bool = True,
*,
include_paused: bool | None = None,
) -> bool:
if include_paused is not None:
close_on_pause = include_paused
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
if include_paused:
if close_on_pause:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False

View File

@ -399,6 +399,8 @@ def _resume_advanced_chat(
workflow_run_id: str,
workflow_run: WorkflowRun,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@ -426,7 +428,7 @@ def _resume_advanced_chat(
user=user,
conversation=conversation,
message=message,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_runtime_state=graph_runtime_state,
@ -436,9 +438,8 @@ def _resume_advanced_chat(
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT)
def _resume_workflow(
@ -455,6 +456,8 @@ def _resume_workflow(
workflow_run_repo,
pause_entity,
) -> None:
resumed_generate_entity = generate_entity.model_copy(update={"stream": True})
try:
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
except ValueError:
@ -480,7 +483,7 @@ def _resume_workflow(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=generate_entity,
application_generate_entity=resumed_generate_entity,
graph_runtime_state=graph_runtime_state,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
@ -490,11 +493,18 @@ def _resume_workflow(
logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id)
raise
if generate_entity.stream:
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
assert isinstance(response, Generator)
_publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW)
workflow_run_repo.delete_workflow_pause(pause_entity)
try:
workflow_run_repo.delete_workflow_pause(pause_entity)
except Exception as exc:
if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc):
raise
logger.info(
"Skipped deleting workflow pause %s after resume because it was already replaced or removed",
pause_entity.id,
)
@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution")

View File

@ -2,27 +2,31 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
from uuid import uuid4
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.human_input_adapter import DeliveryMethodType
from extensions.ext_storage import storage
from graphon.entities import WorkflowExecution
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus
from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -628,12 +632,12 @@ class TestPrivateWorkflowPauseEntity:
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
def test_builds_reason_from_form_definition(
def test_prefers_standalone_web_app_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Build the graph pause reason from the stored form definition."""
"""Use the public standalone web-app token for service API payloads."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
@ -660,6 +664,40 @@ class TestBuildHumanInputRequiredReason:
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
web_app_access_token = secrets.token_urlsafe(8)
web_app_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload="{}",
access_token=web_app_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient])
db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
@ -688,8 +726,15 @@ class TestBuildHumanInputRequiredReason:
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
db_session_with_containers.refresh(backstage_recipient)
db_session_with_containers.refresh(console_recipient)
db_session_with_containers.refresh(web_app_recipient)
reason = _build_human_input_required_reason(reason_model, form_model)
reason = _build_human_input_required_reason(
reason_model,
form_model,
[backstage_recipient, console_recipient, web_app_recipient],
)
assert isinstance(reason, HumanInputRequired)
assert reason.node_title == "Ask Name"
@ -697,3 +742,92 @@ class TestBuildHumanInputRequiredReason:
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"
assert reason.resolved_default_values == {"name": "Alice"}
assert not hasattr(reason, "form_token")
def test_falls_back_to_console_token_when_web_app_token_missing(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Use the console token only when no standalone web-app token exists."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_run_id=str(uuid4()),
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
backstage_access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=backstage_access_token,
)
console_access_token = secrets.token_urlsafe(8)
console_recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.CONSOLE,
recipient_payload="{}",
access_token=console_access_token,
)
db_session_with_containers.add_all([backstage_recipient, console_recipient])
db_session_with_containers.flush()
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.flush()
test_scope.state_keys.add(pause.state_object_key)
reason_model = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=form_model.id,
node_id="node-1",
message="",
)
db_session_with_containers.add(reason_model)
db_session_with_containers.commit()
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient])
assert isinstance(reason, HumanInputRequired)
assert not hasattr(reason, "form_token")

View File

@ -0,0 +1,37 @@
"""Unit tests for the standalone Swagger export helper."""
import importlib.util
import json
import sys
from pathlib import Path
def _load_generate_swagger_specs_module():
api_dir = Path(__file__).resolve().parents[3]
script_path = api_dir / "dev" / "generate_swagger_specs.py"
spec = importlib.util.spec_from_file_location("generate_swagger_specs", script_path)
assert spec
assert spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module) # type: ignore[attr-defined]
return module
def test_generate_specs_writes_console_web_and_service_swagger_files(tmp_path):
module = _load_generate_swagger_specs_module()
written_paths = module.generate_specs(tmp_path)
assert [path.name for path in written_paths] == [
"console-swagger.json",
"web-swagger.json",
"service-swagger.json",
]
for path in written_paths:
payload = json.loads(path.read_text(encoding="utf-8"))
assert payload["swagger"] == "2.0"
assert "paths" in payload

View File

@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch)
handler(api, form_token="token")
def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)

View File

@ -68,7 +68,10 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
@ -85,12 +88,55 @@ class TestChangeEmailSend:
email="new@example.com",
old_email="current@example.com",
language="en-US",
phase="new_email",
phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
mock_send_email.assert_not_called()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@ -122,7 +168,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@ -138,11 +189,169 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
"user@example.com",
code="1234",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
},
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_upgrade_new_phase_token_to_new_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "new@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
}
mock_generate_token.return_value = (None, "new-verified-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
mock_generate_token.assert_called_once_with(
"new@example.com",
code="1234",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
},
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_validity_when_token_phase_is_unknown(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else",
}
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "user@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailCheckApi().post()
mock_revoke_token.assert_not_called()
mock_generate_token.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_validity_when_token_has_no_phase(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
}
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "user@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailCheckApi().post()
mock_revoke_token.assert_not_called()
mock_generate_token.assert_not_called()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@ -175,7 +384,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_get_data.return_value = {
"email": "new@example.com",
"old_email": "OLD@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@ -194,6 +407,155 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_reset_when_token_phase_is_not_new_verified(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
# Simulate a token straight out of step #1 (phase=old_email) — exactly
# the replay used in the advisory PoC.
mock_get_data.return_value = {
"email": "old@example.com",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "attacker@example.com", "token": "token-from-step1"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"email": "verified@example.com",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "attacker@example.com", "token": "token-verified"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
class TestAccountServiceSendChangeEmailEmail:
"""Service-level coverage for the phase-bound changes in `send_change_email_email`."""
def test_should_raise_value_error_for_invalid_phase(self):
with pytest.raises(ValueError, match="phase must be one of"):
AccountService.send_change_email_email(
email="user@example.com",
old_email="user@example.com",
phase="old_email_verified",
)
@patch("services.account_service.send_change_mail_task")
@patch("services.account_service.AccountService.change_email_rate_limiter")
@patch("services.account_service.AccountService.generate_change_email_token")
def test_should_stamp_phase_into_generated_token(
self,
mock_generate_token,
mock_rate_limiter,
mock_mail_task,
):
mock_rate_limiter.is_rate_limited.return_value = False
mock_generate_token.return_value = ("123456", "the-token")
returned = AccountService.send_change_email_email(
email="user@example.com",
old_email="user@example.com",
language="en-US",
phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
)
assert returned == "the-token"
mock_generate_token.assert_called_once_with(
"user@example.com",
None,
old_email="user@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
},
)
mock_mail_task.delay.assert_called_once()
mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com")
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")

View File

@ -2,14 +2,17 @@ from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.workspace.endpoint import (
EndpointCreateApi,
EndpointDeleteApi,
DeprecatedEndpointCreateApi,
DeprecatedEndpointDeleteApi,
DeprecatedEndpointUpdateApi,
EndpointCollectionApi,
EndpointDisableApi,
EndpointEnableApi,
EndpointItemApi,
EndpointListApi,
EndpointListForSinglePluginApi,
EndpointUpdateApi,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
@ -35,9 +38,9 @@ def patch_current_account(user_and_tenant):
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCreateApi:
class TestEndpointCollectionApi:
def test_create_success(self, app):
api = EndpointCreateApi()
api = EndpointCollectionApi()
method = unwrap(api.post)
payload = {
@ -55,7 +58,7 @@ class TestEndpointCreateApi:
assert result["success"] is True
def test_create_permission_denied(self, app):
api = EndpointCreateApi()
api = EndpointCollectionApi()
method = unwrap(api.post)
payload = {
@ -75,7 +78,7 @@ class TestEndpointCreateApi:
method(api)
def test_create_validation_error(self, app):
api = EndpointCreateApi()
api = EndpointCollectionApi()
method = unwrap(api.post)
payload = {
@ -91,6 +94,27 @@ class TestEndpointCreateApi:
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointCreateApi:
def test_create_success(self, app):
api = DeprecatedEndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {"a": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
@ -146,9 +170,96 @@ class TestEndpointListForSinglePluginApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDeleteApi:
class TestEndpointItemApi:
def test_delete_success(self, app):
api = EndpointDeleteApi()
api = EndpointItemApi()
method = unwrap(api.delete)
with (
app.test_request_context("/", method="DELETE"),
patch(
"controllers.console.workspace.endpoint.EndpointService.delete_endpoint",
return_value=True,
) as mock_delete,
):
result = method(api, "e1")
assert result["success"] is True
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
def test_delete_service_failure(self, app):
api = EndpointItemApi()
method = unwrap(api.delete)
with (
app.test_request_context("/", method="DELETE"),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api, "e1")
assert result["success"] is False
def test_update_success(self, app):
api = EndpointItemApi()
method = unwrap(api.patch)
payload = {
"name": "new-name",
"settings": {"x": 1},
}
with (
app.test_request_context("/", method="PATCH", json=payload),
patch(
"controllers.console.workspace.endpoint.EndpointService.update_endpoint",
return_value=True,
) as mock_update,
):
result = method(api, "e1")
assert result["success"] is True
mock_update.assert_called_once_with(
tenant_id="t1",
user_id="u1",
endpoint_id="e1",
name="new-name",
settings={"x": 1},
)
def test_update_validation_error(self, app):
api = EndpointItemApi()
method = unwrap(api.patch)
payload = {"settings": {}}
with (
app.test_request_context("/", method="PATCH", json=payload),
):
with pytest.raises(ValueError):
method(api, "e1")
def test_update_service_failure(self, app):
api = EndpointItemApi()
method = unwrap(api.patch)
payload = {
"name": "n",
"settings": {},
}
with (
app.test_request_context("/", method="PATCH", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api, "e1")
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointDeleteApi:
def test_delete_success(self, app):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -162,7 +273,7 @@ class TestEndpointDeleteApi:
assert result["success"] is True
def test_delete_invalid_payload(self, app):
api = EndpointDeleteApi()
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
with (
@ -172,7 +283,7 @@ class TestEndpointDeleteApi:
method(api)
def test_delete_service_failure(self, app):
api = EndpointDeleteApi()
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -187,9 +298,9 @@ class TestEndpointDeleteApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointUpdateApi:
class TestDeprecatedEndpointUpdateApi:
def test_update_success(self, app):
api = EndpointUpdateApi()
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
payload = {
@ -207,7 +318,7 @@ class TestEndpointUpdateApi:
assert result["success"] is True
def test_update_validation_error(self, app):
api = EndpointUpdateApi()
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
@ -219,7 +330,7 @@ class TestEndpointUpdateApi:
method(api)
def test_update_service_failure(self, app):
api = EndpointUpdateApi()
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
payload = {
@ -237,6 +348,36 @@ class TestEndpointUpdateApi:
assert result["success"] is False
class TestEndpointRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedEndpointCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedEndpointDeleteApi.post.__apidoc__["deprecated"] is True
assert DeprecatedEndpointUpdateApi.post.__apidoc__["deprecated"] is True
assert EndpointCollectionApi.post.__apidoc__.get("deprecated") is not True
assert EndpointItemApi.delete.__apidoc__.get("deprecated") is not True
assert EndpointItemApi.patch.__apidoc__.get("deprecated") is not True
def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"EndpointCollectionApi",
"EndpointItemApi",
"DeprecatedEndpointCreateApi",
"DeprecatedEndpointDeleteApi",
"DeprecatedEndpointUpdateApi",
}
}
assert route_map["EndpointCollectionApi"] == ("/workspaces/current/endpoints",)
assert route_map["EndpointItemApi"] == ("/workspaces/current/endpoints/<string:id>",)
assert route_map["DeprecatedEndpointCreateApi"] == ("/workspaces/current/endpoints/create",)
assert route_map["DeprecatedEndpointDeleteApi"] == ("/workspaces/current/endpoints/delete",)
assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):

View File

@ -0,0 +1,707 @@
"""Dedicated tests for HITL behavior exposed through the Service API."""
from __future__ import annotations
import json
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
HumanInputRequiredResponse,
WorkflowAppPausedBlockingResponse,
WorkflowPauseStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.human_input_policy import HumanInputSurface
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.enums import FormInputType
from graphon.runtime import GraphRuntimeState, VariablePool
from models.account import Account
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.app_generate_service import AppGenerateService
from services.workflow_event_snapshot_service import _build_snapshot_events
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class _DummyRateLimit:
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
def _build_service_api_pause_converter() -> WorkflowResponseConverter:
application_generate_entity = SimpleNamespace(
inputs={},
files=[],
invoke_from=InvokeFrom.SERVICE_API,
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
)
system_variables = build_system_variables(
user_id="user",
app_id="app-id",
workflow_id="workflow-id",
workflow_execution_id="run-id",
)
user = MagicMock(spec=Account)
user.id = "account-id"
user.name = "Tester"
user.email = "tester@example.com"
return WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=system_variables,
)
def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"expiration_time": 100,
}
],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
return WorkflowAppPausedBlockingResponse(
task_id="t1",
workflow_run_id="r1",
data=WorkflowAppPausedBlockingResponse.Data(
id="r1",
workflow_id="wf-1",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=0.5,
total_tokens=0,
total_steps=2,
created_at=1,
finished_at=None,
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
),
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
class TestHitlServiceApi:
# Service API event-stream continuation
def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=[],
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
method="GET",
):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once_with(
app_mode=AppMode.WORKFLOW,
workflow_run=workflow_run,
tenant_id="tenant-1",
app_id="app-1",
session_maker=ANY,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=False,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False)
monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit)
workflow = MagicMock()
workflow.created_by = "owner-id"
monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow)
monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker")
generator_instance = MagicMock()
generator_instance.generate.return_value = {"result": "advanced-blocking"}
generator_instance.convert_to_event_stream.side_effect = lambda payload: payload
monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance)
app_model = MagicMock()
app_model.mode = AppMode.ADVANCED_CHAT
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args={"workflow_id": None, "query": "hi", "inputs": {}},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
)
assert result == {"result": "advanced-blocking"}
call_kwargs = generator_instance.generate.call_args.kwargs
assert call_kwargs["streaming"] is False
assert call_kwargs["pause_state_config"] is not None
assert call_kwargs["pause_state_config"].session_factory == "session-maker"
assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id"
# Blocking payload contract
def test_advanced_chat_blocking_pause_payload_contract(self) -> None:
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(
_build_advanced_chat_paused_blocking_response()
)
assert response["event"] == "workflow_paused"
assert response["workflow_run_id"] == "run-1"
assert response["answer"] == "partial"
assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED
assert response["data"]["reasons"][0]["expiration_time"] == 100
assert "human_input_forms" not in response["data"]
def test_workflow_blocking_pause_payload_contract(self) -> None:
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(
_build_workflow_paused_blocking_response()
)
assert response["workflow_run_id"] == "r1"
assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED
assert response["data"]["paused_nodes"] == ["node-1"]
assert response["data"]["reasons"] == [
{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}
]
assert "human_input_forms" not in response["data"]
def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None:
from core.app.app_config.entities import AppAdditionalFeatures
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from models.enums import MessageStatus
from models.model import EndUser
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.ADVANCED_CHAT,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
query="hello",
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
extras={},
trace_manager=None,
workflow_run_id="run-id",
)
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT),
message=SimpleNamespace(
id="message-id",
query="hello",
created_at=datetime.utcnow(),
status=MessageStatus.NORMAL,
answer="",
),
user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"),
stream=False,
dialogue_count=1,
draft_var_saver_factory=lambda **kwargs: None,
)
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run-id",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run-id",
paused_nodes=["node-1"],
outputs={},
reasons=[
{
"type": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"expiration_time": 123,
},
],
status="paused",
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.answer == "partial answer"
assert response.data.workflow_run_id == "run-id"
assert response.data.reasons[0]["form_id"] == "form-1"
assert response.data.reasons[0]["expiration_time"] == 123
def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant",
app_id="app",
app_mode=AppMode.WORKFLOW,
additional_features=AppAdditionalFeatures(),
variables=[],
workflow_id="workflow-id",
)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
task_id="task",
app_config=app_config,
inputs={},
files=[],
user_id="user",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
trace_manager=None,
workflow_execution_id="run-id",
extras={},
call_depth=0,
)
pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
user=SimpleNamespace(id="user", session_id="session"),
stream=False,
draft_var_saver_factory=lambda **kwargs: None,
)
monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
yield WorkflowPauseStreamResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
paused_nodes=["node-1"],
reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}],
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}]
def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None:
converter = _build_service_api_pause_converter()
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
reason=WorkflowStartReason.INITIAL,
)
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
class _FakeSession:
def execute(self, _stmt):
return [("form-1", expiration_time, '{"display_in_ui": true}')]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
workflow_response_converter,
"load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "token"},
)
reason = HumanInputRequired(
form_id="form-1",
form_content="Rendered",
inputs=[
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
node_id="node-id",
node_title="Human Step",
form_token="token",
)
queue_event = QueueWorkflowPausedEvent(
reasons=[reason],
outputs={"answer": "value"},
paused_nodes=["node-id"],
)
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
responses = converter.workflow_pause_to_stream_response(
event=queue_event,
task_id="task",
graph_runtime_state=runtime_state,
)
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
pause_resp = responses[-1]
assert pause_resp.workflow_run_id == "run-id"
assert pause_resp.data.paused_nodes == ["node-id"]
assert pause_resp.data.outputs == {}
assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required"
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
assert pause_resp.data.reasons[0]["form_token"] == "token"
assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp())
assert isinstance(responses[0], HumanInputRequiredResponse)
hi_resp = responses[0]
assert hi_resp.data.form_id == "form-1"
assert hi_resp.data.node_id == "node-id"
assert hi_resp.data.node_title == "Human Step"
assert hi_resp.data.inputs[0].output_variable_name == "field"
assert hi_resp.data.actions[0].id == "approve"
assert hi_resp.data.display_in_ui is True
assert hi_resp.data.form_token == "token"
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
# Snapshot payload contract
def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
monkeypatch.setattr(
"services.workflow_event_snapshot_service.load_form_tokens_by_form_id",
lambda form_ids, session=None, surface=None: {"form-1": "wtok"},
)
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def session_maker() -> _SessionContext:
return _SessionContext(
SimpleNamespace(
execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')],
)
)
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
form_token="wtok",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
session_maker=session_maker,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"human_input_required",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
assert events[3]["data"]["form_token"] == "wtok"
assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
assert pause_data["reasons"][0]["TYPE"] == "human_input_required"
assert pause_data["reasons"][0]["form_token"] == "wtok"
assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp())
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
assert pause_data["total_tokens"] == workflow_run.total_tokens
assert pause_data["total_steps"] == workflow_run.total_steps

View File

@ -0,0 +1,184 @@
"""Unit tests for Service API human input form endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
from models.human_input import RecipientType
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class TestWorkflowHumanInputFormApi:
def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
definition = SimpleNamespace(
model_dump=lambda: {
"rendered_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
"user_actions": [{"id": "approve", "title": "Approve"}],
}
)
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
get_definition=lambda: definition,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
response = handler(api, app_model=app_model, form_token="token-1")
payload = json.loads(response.get_data(as_text=True))
assert payload == {
"form_content": "Rendered form content",
"inputs": [{"output_variable_name": "name"}],
"resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'},
"user_actions": [{"id": "approve", "title": "Approve"}],
"expiration_time": int(form.expiration_time.timestamp()),
}
service_mock.get_form_by_token.assert_called_once_with("token-1")
service_mock.ensure_form_active.assert_called_once_with(form)
def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="another-app",
tenant_id="tenant-1",
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_get_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, form_token="token-1")
service_mock.ensure_form_active.assert_not_called()
def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=RecipientType.STANDALONE_WEB_APP,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
assert response == {}
assert status == 200
service_mock.submit_form_by_token.assert_called_once_with(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token-1",
selected_action_id="approve",
form_data={"name": "Alice"},
submission_end_user_id="end-user-1",
)
@pytest.mark.parametrize(
"recipient_type",
[
RecipientType.CONSOLE,
RecipientType.BACKSTAGE,
RecipientType.EMAIL_MEMBER,
RecipientType.EMAIL_EXTERNAL,
],
)
def test_post_rejects_non_service_api_recipient_types(
self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType
) -> None:
form = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
recipient_type=recipient_type,
)
service_mock = Mock()
service_mock.get_form_by_token.return_value = form
workflow_module = sys.modules["controllers.service_api.app.human_input_form"]
monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock)
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(
"/form/human_input/token-1",
method="POST",
json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, form_token="token-1")
service_mock.submit_form_by_token.assert_not_called()

View File

@ -0,0 +1,166 @@
"""Unit tests for Service API workflow event stream endpoints."""
from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
from models.model import AppMode
from tests.unit_tests.controllers.service_api.conftest import _unwrap
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
monkeypatch.setattr(
workflow_events_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
return workflow_events_module
class TestWorkflowEventsApi:
def test_wrong_app_mode(self, app) -> None:
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="another-user",
finished_at=None,
)
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=datetime(2099, 1, 1, tzinfo=UTC),
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
monkeypatch.setattr(
workflow_events_module.WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: SimpleNamespace(
model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
event=SimpleNamespace(value="workflow_finished"),
),
)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.mimetype == "text/event-stream"
body = response.get_data(as_text=True).strip()
assert body.startswith("data: ")
payload = json.loads(body[len("data: ") :])
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
msg_generator.retrieve_events.return_value = ["raw-event"]
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: streamed\n\n"
msg_generator.retrieve_events.assert_called_once_with(
AppMode.WORKFLOW,
"run-1",
terminal_events=None,
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="end-user-1",
finished_at=None,
)
workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
msg_generator = Mock()
workflow_generator = Mock()
workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
snapshot_builder = Mock(return_value=["snapshot-events"])
monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called()
snapshot_builder.assert_called_once()
workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])

View File

@ -22,6 +22,8 @@ import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.service_api.dataset.document import (
DeprecatedDocumentAddByTextApi,
DeprecatedDocumentUpdateByTextApi,
DocumentAddByFileApi,
DocumentAddByTextApi,
DocumentApi,
@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi:
# Act
with app.test_request_context(
f"/datasets/{mock_dataset.id}/document/create_by_text",
f"/datasets/{mock_dataset.id}/document/create-by-text",
method="POST",
json={
"name": "Test Document",
@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi:
# Act & Assert
with app.test_request_context(
f"/datasets/{mock_dataset.id}/document/create_by_text",
f"/datasets/{mock_dataset.id}/document/create-by-text",
method="POST",
json={"name": "Test Document", "text": "Content"},
headers={"Authorization": "Bearer test_token"},
@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi:
# Act & Assert
with app.test_request_context(
f"/datasets/{mock_dataset.id}/document/create_by_text",
f"/datasets/{mock_dataset.id}/document/create-by-text",
method="POST",
json={"name": "Test Document", "text": "Content"},
headers={"Authorization": "Bearer test_token"},
@ -1093,6 +1095,20 @@ class TestArchivedDocumentImmutableError:
assert error.code == 403
class TestDocumentTextRouteDeprecation:
"""Test that legacy underscore text routes stay marked deprecated."""
def test_create_by_text_legacy_alias_is_deprecated(self):
"""Ensure only the legacy create-by-text alias is marked deprecated."""
assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True
assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True
def test_update_by_text_legacy_alias_is_deprecated(self):
"""Ensure only the legacy update-by-text alias is marked deprecated."""
assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True
assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True
# =============================================================================
# Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi,
# DocumentUpdateByFileApi.
@ -1162,7 +1178,7 @@ class TestDocumentUpdateByTextApiPost:
doc_id = str(uuid.uuid4())
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text",
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text",
method="POST",
json={"name": "Updated Doc", "text": "New content"},
headers={"Authorization": "Bearer test_token"},
@ -1195,7 +1211,7 @@ class TestDocumentUpdateByTextApiPost:
doc_id = str(uuid.uuid4())
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text",
f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text",
method="POST",
json={"name": "Doc", "text": "Content"},
headers={"Authorization": "Bearer test_token"},

View File

@ -1,7 +1,10 @@
from collections.abc import Generator
import pytest
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -10,7 +13,8 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
)
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
class TestAdvancedChatGenerateResponseConverter:
@ -28,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter:
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
assert "usage" not in response["metadata"]
def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch):
data = AdvancedChatPausedBlockingResponse.Data(
id="msg-1",
mode="chat",
conversation_id="c1",
message_id="m1",
workflow_run_id="run-1",
answer="partial",
metadata={"usage": {"total_tokens": 1}},
created_at=1,
paused_nodes=["node-1"],
reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}],
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
)
original_model_dump = type(data).model_dump
def _model_dump_with_future_field(self, *args, **kwargs):
payload = original_model_dump(self, *args, **kwargs)
payload["future_field"] = "future-value"
return payload
monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field)
blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
assert response["data"]["future_field"] == "future-value"
def test_stream_simple_response_includes_node_events(self):
node_start = NodeStartStreamResponse(
task_id="t1",

View File

@ -39,15 +39,19 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
AnnotationReply,
AnnotationReplyAccount,
HumanInputRequiredResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from core.workflow.system_variables import build_system_variables
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.human_input.entities import UserAction
from graphon.runtime import GraphRuntimeState, VariablePool
from libs.datetime_utils import naive_utc_now
from models.enums import MessageStatus
@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline:
assert response.data.answer == "done"
assert response.data.metadata == {"k": "v"}
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
start_at=0.0,
total_tokens=7,
node_run_steps=3,
)
def _gen():
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Approval",
form_content="Need approval",
inputs=[],
actions=[UserAction(id="approve", title="Approve")],
display_in_ui=True,
form_token="token-1",
resolved_default_values={},
expiration_time=123,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, AdvancedChatPausedBlockingResponse)
assert response.data.workflow_run_id == "run-id"
assert response.data.status == "paused"
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [
{
"TYPE": PauseReasonType.HUMAN_INPUT_REQUIRED,
"form_id": "form-1",
"node_id": "node-1",
"node_title": "Approval",
"form_content": "Need approval",
"inputs": [],
"actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
"display_in_ui": True,
"form_token": "token-1",
"resolved_default_values": {},
"expiration_time": 123,
}
]
def test_handle_text_chunk_event_updates_state(self):
pipeline = _make_pipeline()
pipeline._message_cycle_manager = SimpleNamespace(

View File

@ -0,0 +1,102 @@
from __future__ import annotations
from collections.abc import Generator
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import (
AppStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from graphon.enums import WorkflowExecutionStatus
class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
blocking_full_calls: list[WorkflowAppBlockingResponse] = []
blocking_simple_calls: list[WorkflowAppBlockingResponse] = []
stream_full_calls: list[Generator[AppStreamResponse, None, None]] = []
stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = []
@classmethod
def reset(cls) -> None:
cls.blocking_full_calls = []
cls.blocking_simple_calls = []
cls.stream_full_calls = []
cls.stream_simple_calls = []
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
cls.blocking_full_calls.append(blocking_response)
return {"kind": "blocking-full", "task_id": blocking_response.task_id}
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
cls.blocking_simple_calls.append(blocking_response)
return {"kind": "blocking-simple", "task_id": blocking_response.task_id}
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls.stream_full_calls.append(stream_response)
yield {"kind": "stream-full"}
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
cls.stream_simple_calls.append(stream_response)
yield {"kind": "stream-simple"}
def _build_blocking_response() -> WorkflowAppBlockingResponse:
return WorkflowAppBlockingResponse(
task_id="task-1",
workflow_run_id="run-1",
data=WorkflowAppBlockingResponse.Data(
id="run-1",
workflow_id="workflow-1",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs={"ok": True},
error=None,
elapsed_time=0.1,
total_tokens=0,
total_steps=1,
created_at=1,
finished_at=2,
),
)
def _build_stream_response() -> Generator[AppStreamResponse, None, None]:
yield WorkflowAppStreamResponse(
workflow_run_id="run-1",
stream_response=PingStreamResponse(task_id="task-1"),
)
def test_convert_routes_blocking_response_by_invoke_from() -> None:
_DummyConverter.reset()
blocking_response = _build_blocking_response()
full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API)
simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP)
assert full_result == {"kind": "blocking-full", "task_id": "task-1"}
assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"}
assert _DummyConverter.blocking_full_calls == [blocking_response]
assert _DummyConverter.blocking_simple_calls == [blocking_response]
def test_convert_routes_stream_response_by_invoke_from() -> None:
_DummyConverter.reset()
full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API))
simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP))
assert full_result == [{"kind": "stream-full"}]
assert simple_result == [{"kind": "stream-simple"}]
assert len(_DummyConverter.stream_full_calls) == 1
assert len(_DummyConverter.stream_simple_calls) == 1

View File

@ -1,6 +1,7 @@
from unittest.mock import Mock, patch
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
@ -23,7 +24,21 @@ class TestMessageGenerator:
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
) as mock_stream,
):
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
events = list(
MessageGenerator.retrieve_events(
AppMode.WORKFLOW,
"run-1",
idle_timeout=1,
ping_interval=2,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)
)
assert events == [{"event": "ping"}]
mock_stream.assert_called_once()
mock_stream.assert_called_once_with(
topic="topic",
idle_timeout=1,
ping_interval=2,
on_subscribe=None,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)

View File

@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults():
}
def test_normalize_terminal_events_empty_values():
assert _normalize_terminal_events([]) == set({})
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
topic = FakeTopic()
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
assert next(generator) == StreamEvent.PING.value
# next receive yields None -> ping interval triggers
assert next(generator) == StreamEvent.PING.value
def test_stream_topic_events_can_continue_past_pause():
topic = FakeTopic()
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode())
topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode())
generator = stream_topic_events(
topic=topic,
idle_timeout=1.0,
terminal_events=[StreamEvent.WORKFLOW_FINISHED.value],
)
assert next(generator) == StreamEvent.PING.value
assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value
assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value
with pytest.raises(StopIteration):
next(generator)

View File

@ -36,11 +36,12 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
WorkflowAppPausedBlockingResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
@ -91,27 +92,50 @@ def _make_pipeline():
class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_handles_pause(self):
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
start_at=0.0,
total_tokens=5,
node_run_steps=2,
)
def _gen():
yield WorkflowPauseStreamResponse(
yield HumanInputRequiredResponse(
task_id="task",
workflow_run_id="run",
data=WorkflowPauseStreamResponse.Data(
workflow_run_id="run",
status=WorkflowExecutionStatus.PAUSED,
outputs={},
created_at=1,
elapsed_time=0.1,
total_tokens=0,
total_steps=0,
workflow_run_id="run-id",
data=HumanInputRequiredResponse.Data(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="content",
expiration_time=1,
),
)
response = pipeline._to_blocking_response(_gen())
assert isinstance(response, WorkflowAppPausedBlockingResponse)
assert response.workflow_run_id == "run-id"
assert response.data.status == WorkflowExecutionStatus.PAUSED
assert response.data.created_at == 0
assert response.data.paused_nodes == ["node-1"]
assert response.data.reasons == [
{
"TYPE": "human_input_required",
"form_id": "form-1",
"node_id": "node-1",
"node_title": "Human Input",
"form_content": "content",
"inputs": [],
"actions": [],
"display_in_ui": False,
"form_token": None,
"resolved_default_values": {},
"expiration_time": 1,
}
]
def test_to_blocking_response_handles_finish(self):
pipeline = _make_pipeline()

View File

@ -0,0 +1,106 @@
"""Tests for the Creators Platform helper module."""
from unittest.mock import MagicMock, patch
import httpx
import pytest
from yarl import URL
@pytest.fixture(autouse=True)
def _patch_creators_url(monkeypatch):
"""Patch the module-level creators_platform_api_url for all tests."""
monkeypatch.setattr(
"core.helper.creators.creators_platform_api_url",
URL("https://creators.example.com"),
)
class TestUploadDSL:
@patch("core.helper.creators.httpx.post")
def test_returns_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {"claim_code": "abc123"}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
result = upload_dsl(b"app: demo", "demo.yaml")
assert result == "abc123"
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
assert "anonymous-upload" in call_kwargs.args[0]
assert call_kwargs.kwargs["timeout"] == 30
@patch("core.helper.creators.httpx.post")
def test_raises_on_missing_claim_code(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = {"data": {}}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(ValueError, match="claim_code"):
upload_dsl(b"app: demo")
@patch("core.helper.creators.httpx.post")
def test_raises_on_http_error(self, mock_post):
mock_response = MagicMock(spec=httpx.Response)
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=MagicMock(),
)
mock_post.return_value = mock_response
from core.helper.creators import upload_dsl
with pytest.raises(httpx.HTTPStatusError):
upload_dsl(b"app: demo")
class TestGetRedirectUrl:
@patch("core.helper.creators.dify_config")
def test_without_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code" not in url
assert url.startswith("https://creators.example.com")
@patch("core.helper.creators.dify_config")
def test_with_oauth_client_id(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz"
with patch(
"services.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="oauth-code-123",
) as mock_sign:
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
mock_sign.assert_called_once_with("client-xyz", "user-1")
assert "dsl_claim_code=claim-abc" in url
assert "oauth_code=oauth-code-123" in url
@patch("core.helper.creators.dify_config")
def test_strips_trailing_slash(self, mock_config):
mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/"
mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = ""
from core.helper.creators import get_redirect_url
url = get_redirect_url("user-1", "claim-abc")
assert url.startswith("https://creators.example.com?")
assert "creators.example.com/?" not in url

Some files were not shown because too many files have changed in this diff Show More