Merge branch 'main' into tp

This commit is contained in:
JzoNg 2026-04-27 10:20:08 +08:00
commit bdecea34a3
316 changed files with 12484 additions and 7806 deletions

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
api-unit:
name: API Unit Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-unit
defaults:
@ -62,7 +62,7 @@ jobs:
api-integration:
name: API Integration Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-integration
STORAGE_TYPE: opendal
@ -137,7 +137,7 @@ jobs:
api-coverage:
name: API Coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
needs:
- api-unit
- api-integration

View File

@ -13,7 +13,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -26,6 +26,9 @@ jobs:
build:
runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify'
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
@ -35,28 +38,28 @@ jobs:
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
steps:
- name: Prepare
@ -70,8 +73,8 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Extract metadata for Docker
id: meta
@ -81,16 +84,15 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: depot/build-push-action@v1
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest
env:
@ -108,9 +110,33 @@ jobs:
if-no-files-found: error
retention-days: 1
fork-build-validate:
if: github.repository != 'langgenius/dify'
runs-on: ubuntu-24.04
strategy:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Validate Docker image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: linux/amd64
create-manifest:
needs: build
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
strategy:
matrix:

View File

@ -9,7 +9,7 @@ concurrency:
jobs:
db-migration-test-postgres:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -59,7 +59,7 @@ jobs:
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev'

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'

View File

@ -14,40 +14,69 @@ concurrency:
jobs:
build-docker:
if: github.event.pull_request.head.repo.full_name == github.repository
runs-on: ${{ matrix.runs_on }}
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Build Docker Image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: depot/build-push-action@v1
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-docker-fork:
if: github.event.pull_request.head.repo.full_name != github.repository
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Build Docker Image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: linux/amd64

View File

@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:

View File

@ -23,7 +23,7 @@ concurrency:
jobs:
pre_job:
name: Skip Duplicate Checks
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps:
@ -39,7 +39,7 @@ jobs:
name: Check Changed Files
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
@ -141,7 +141,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests."
@ -154,7 +154,7 @@ jobs:
- check-changes
- api-tests-run
- api-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize API Tests status
env:
@ -201,7 +201,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests."
@ -214,7 +214,7 @@ jobs:
- check-changes
- web-tests-run
- web-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Tests status
env:
@ -260,7 +260,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@ -273,7 +273,7 @@ jobs:
- check-changes
- web-e2e-run
- web-e2e-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Full-Stack E2E status
env:
@ -325,7 +325,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests."
@ -338,7 +338,7 @@ jobs:
- check-changes
- vdb-tests-run
- vdb-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize VDB Tests status
env:
@ -384,7 +384,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests."
@ -397,7 +397,7 @@ jobs:
- check-changes
- db-migration-test-run
- db-migration-test-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize DB Migration Test status
env:

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-diff:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -16,7 +16,7 @@ jobs:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -12,7 +12,7 @@ on:
jobs:
stale:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
issues: write
pull-requests: write

View File

@ -15,7 +15,7 @@ permissions:
jobs:
python-style:
name: Python Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -57,7 +57,7 @@ jobs:
web-style:
name: Web Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:
working-directory: ./web
@ -131,7 +131,7 @@ jobs:
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code

View File

@ -18,7 +18,7 @@ concurrency:
jobs:
build:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:

View File

@ -35,7 +35,7 @@ concurrency:
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 120
steps:
@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 5
steps:

View File

@ -16,7 +16,7 @@ jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: VDB Smoke Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: Web Full-Stack E2E
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:
shell: bash

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
@ -54,7 +54,7 @@ jobs:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
@ -92,7 +92,7 @@ jobs:
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:

View File

@ -147,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

View File

@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
oauth_client_params = encrypt_system_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
    """
    Configuration for Creators Platform integration
    """

    # Feature switch for the Creators Platform integration; enabled by default.
    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
        default=True,
        description="Enable or disable Creators Platform features",
    )

    # Base URL of the Creators Platform API.
    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
        default=HttpUrl("https://creators.dify.ai"),
        description="Creators Platform API URL",
    )

    # OAuth client ID for the integration; an empty string is the default.
    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
        default="",
        description="OAuth client ID for Creators Platform integration",
    )
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1379,6 +1400,7 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, JsonValue
# Request body model for submitting a human input form.
class HumanInputFormSubmitPayload(BaseModel):
# Submitted form values; each value may be any JSON-compatible value.
# NOTE(review): keys are presumably the form's input names - confirm against the form definition.
inputs: dict[str, JsonValue]
# Identifier of the user action chosen for this submission.
action: str

View File

@ -692,6 +692,32 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
    # Exports the app's DSL, uploads it through the creators helper, and
    # returns a redirect URL to the caller.
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=None)
    @edit_permission_required
    def post(self, app_model):
        """Publish app to Creators Platform"""
        from configs import dify_config
        from core.helper.creators import get_redirect_url, upload_dsl

        # Refuse early when the integration is switched off in configuration.
        features_enabled = dify_config.CREATORS_PLATFORM_FEATURES_ENABLED
        if not features_enabled:
            return {"error": "Creators Platform features are not enabled"}, 403

        account, _ = current_account_with_tenant()

        # The DSL is exported without secrets before it leaves this service.
        exported_dsl = AppDslService.export_dsl(app_model=app_model, include_secret=False)
        claim_code = upload_dsl(exported_dsl.encode("utf-8"))

        # The redirect URL is derived from the current account ID and the
        # claim code returned by the upload.
        return {"redirect_url": get_redirect_url(str(account.id), claim_code)}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
    """Raise ``NotFoundError`` unless the form's recipient type is allowed on the console surface."""
    permitted = is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE)
    if permitted:
        return
    raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enough to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type")
# Request body for DELETE /tag-bindings/<id>: the tag ID comes from the URL,
# so the body only carries the target and the tag type.
class TagBindingItemDeletePayload(BaseModel):
# ID of the resource the tag is being unbound from.
target_id: str = Field(description="Target ID to unbind tag from")
# Tag type, passed through to the tag service.
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@ -70,6 +75,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
return "", 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
    """
    Reject the request unless the current account may edit tag bindings.

    An account qualifies when it has edit permission or is a dataset editor.

    Raises:
        Forbidden: the current account has neither capability.
    """
    account, _ = current_account_with_tenant()

    # Only admins, owners, editors, or dataset operators may proceed.
    if account.has_edit_permission or account.is_dataset_editor:
        return
    raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
    """
    Bind the tags named in the request body to a target.

    Runs the edit-permission check, validates the body as ``TagBindingPayload``
    and hands the binding to ``TagService.save_tag_binding``.

    Returns:
        A ``({"result": "success"}, 200)`` response tuple.
    """
    _require_tag_binding_edit_permission()

    request_body = TagBindingPayload.model_validate(console_ns.payload or {})
    binding = TagBindingCreatePayload(
        tag_ids=request_body.tag_ids,
        target_id=request_body.target_id,
        type=request_body.type,
    )
    TagService.save_tag_binding(binding)
    return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
    """
    Remove one tag binding described entirely by the request body.

    Runs the edit-permission check, validates the body as
    ``TagBindingRemovePayload`` and hands it to ``TagService.delete_tag_binding``.

    Returns:
        A ``({"result": "success"}, 200)`` response tuple.
    """
    _require_tag_binding_edit_permission()

    request_body = TagBindingRemovePayload.model_validate(console_ns.payload or {})
    unbinding = TagBindingDeletePayload(
        tag_id=request_body.tag_id,
        target_id=request_body.target_id,
        type=request_body.type,
    )
    TagService.delete_tag_binding(unbinding)
    return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
return _create_tag_bindings()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
    """Canonical item resource for tag binding deletion."""

    @console_ns.doc("delete_tag_binding")
    @console_ns.doc(params={"id": "Tag ID"})
    @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, id):
        _require_tag_binding_edit_permission()

        request_body = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})

        # The tag ID is taken from the URL path; the body supplies only the
        # target and the tag type.
        unbinding = TagBindingDeletePayload(
            tag_id=str(id),
            target_id=request_body.target_id,
            type=request_body.type,
        )
        TagService.delete_tag_binding(unbinding)
        return {"result": "success"}, 200
# Legacy POST /tag-bindings/create route; marked deprecated in the API docs.
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
# Delegates to the shared helper, which performs the permission check,
# payload validation and the service call.
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _remove_tag_binding()

View File

@ -23,9 +23,11 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@ -50,6 +52,7 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@ -58,6 +61,7 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@ -0,0 +1,137 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
    """
    Build the JSON response that describes a human input form.

    The body exposes the rendered content, the inputs, the stringified default
    values, the user actions and the expiration time as an integer timestamp.
    """
    definition = form.get_definition().model_dump()
    body = json.dumps(
        {
            "form_content": definition["rendered_content"],
            "inputs": definition["inputs"],
            "resolved_default_values": _stringify_default_values(definition["default_values"]),
            "user_actions": definition["user_actions"],
            "expiration_time": _to_timestamp(form.expiration_time),
        },
        # Keep non-ASCII text readable instead of \u-escaping it.
        ensure_ascii=False,
    )
    return Response(body, mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
    """Raise ``NotFound`` unless the form's app and tenant both match ``app_model``."""
    # A mismatch is reported as "not found" whichever field differs.
    if form.app_id != app_model.id:
        raise NotFound("Form not found")
    if form.tenant_id != app_model.tenant_id:
        raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
    """Raise ``NotFound`` unless the form's recipient type is allowed on the service API surface."""
    # Keep app-token callers scoped to the public web-form surface; internal HITL
    # routes must continue to flow through console-only authentication.
    permitted = is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API)
    if permitted:
        return
    raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
    # GET returns the definition of a human input form addressed by its token;
    # POST submits it on behalf of an end user.
    @service_api_ns.doc("get_human_input_form")
    @service_api_ns.doc(description="Get a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token
    def get(self, app_model: App, form_token: str):
        human_input = HumanInputService(db.engine)
        found = human_input.get_form_by_token(form_token)
        if found is None:
            raise NotFound("Form not found")

        # Ownership and surface checks run before the active-state check.
        _ensure_form_belongs_to_app(found, app_model)
        _ensure_form_is_allowed_for_service_api(found)
        human_input.ensure_form_active(found)

        return _jsonify_form_definition(found)

    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
    @service_api_ns.doc("submit_human_input_form")
    @service_api_ns.doc(description="Submit a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form submitted successfully",
            400: "Bad request - invalid submission data",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, form_token: str):
        # Validate the body first so a malformed submission fails before any lookup.
        submission = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})

        human_input = HumanInputService(db.engine)
        found = human_input.get_form_by_token(form_token)
        if found is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(found, app_model)
        _ensure_form_is_allowed_for_service_api(found)

        # The submit call needs a concrete recipient type; a form without one
        # is logged and rejected as a bad request.
        form_recipient_type = found.recipient_type
        if form_recipient_type is None:
            logger.warning("Recipient type is None for form, form_id=%s", found.id)
            raise BadRequest("Form recipient type is invalid")

        try:
            human_input.submit_form_by_token(
                recipient_type=form_recipient_type,
                form_token=form_token,
                selected_action_id=submission.action,
                form_data=submission.inputs,
                submission_end_user_id=end_user.id,
            )
        except FormNotFoundError:
            # Translate the service-level error into an HTTP 404.
            raise NotFound("Form not found")

        return {}, 200

View File

@ -0,0 +1,142 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
    """Service API for getting workflow execution events after resume."""

    @service_api_ns.doc("get_workflow_events")
    @service_api_ns.doc(description="Get workflow execution events stream after resume")
    @service_api_ns.doc(
        params={
            "task_id": "Workflow run ID",
            "user": "End user identifier (query param)",
            "include_state_snapshot": (
                "Whether to replay from persisted state snapshot, "
                'specify `"true"` to include a status snapshot of executed nodes'
            ),
            "continue_on_pause": (
                "Whether to keep the stream open across workflow_paused events, "
                'specify `"true"` to keep the stream open for `workflow_paused` events.'
            ),
        }
    )
    @service_api_ns.doc(
        responses={
            200: "SSE event stream",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, task_id: str):
        """
        Stream the events of a workflow run as a ``text/event-stream`` response.

        If the run has already finished, a single finish event built from the
        persisted run is emitted. Otherwise events are streamed live, optionally
        preceded by a state snapshot when ``include_state_snapshot=true``.

        :param app_model: app resolved from the API token
        :param end_user: end user resolved from the ``user`` query parameter
        :param task_id: ID of the workflow run to follow
        :return: a Flask ``Response`` wrapping the event generator
        :raises NotWorkflowAppError: when the app is neither a workflow nor an advanced chat app
        :raises NotFound: when the run does not exist in the tenant, belongs to another
            app, or was not created by this end user
        """
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            raise NotWorkflowAppError()

        session_maker = sessionmaker(db.engine)
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
            tenant_id=app_model.tenant_id,
            run_id=task_id,
        )
        if workflow_run is None:
            raise NotFound("Workflow run not found")

        # Every ownership mismatch is reported with the same 404 as a missing run,
        # so the response does not distinguish "exists but not yours" from "absent".
        if (
            workflow_run.app_id != app_model.id
            or workflow_run.created_by_role != CreatorUserRole.END_USER
            or workflow_run.created_by != end_user.id
        ):
            raise NotFound("Workflow run not found")

        if workflow_run.finished_at is not None:
            # The run is already over: emit one finish event rebuilt from the stored run.
            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
                task_id=workflow_run.id,
                workflow_run=workflow_run,
                creator_user=end_user,
            )
            payload = response.model_dump(mode="json")
            payload["event"] = response.event.value

            def _generate_finished_events() -> Generator[str, None, None]:
                yield f"data: {json.dumps(payload)}\n\n"

            event_generator = _generate_finished_events
        else:
            msg_generator = MessageGenerator()
            generator: BaseAppGenerator
            if app_mode == AppMode.ADVANCED_CHAT:
                generator = AdvancedChatAppGenerator()
            elif app_mode == AppMode.WORKFLOW:
                generator = WorkflowAppGenerator()
            else:
                # Defensive: the mode was already validated at the top of the method.
                raise NotWorkflowAppError()

            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
            continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
            # ``None`` leaves the choice of terminal events to ``retrieve_events``;
            # an empty list is passed explicitly when the stream should continue on pause.
            # NOTE(review): an empty list names no terminal event at all - confirm the
            # live stream still closes on workflow_finished in this mode.
            terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None

            def _generate_stream_events():
                if include_state_snapshot:
                    return generator.convert_to_event_stream(
                        build_workflow_event_stream(
                            app_mode=app_mode,
                            workflow_run=workflow_run,
                            tenant_id=app_model.tenant_id,
                            app_id=app_model.id,
                            session_maker=session_maker,
                            human_input_surface=HumanInputSurface.SERVICE_API,
                            close_on_pause=not continue_on_pause,
                        )
                    )
                return generator.convert_to_event_stream(
                    msg_generator.retrieve_events(
                        app_mode,
                        workflow_run.id,
                        terminal_events=terminal_events,
                    ),
                )

            event_generator = _generate_stream_events

        return Response(
            event_generator(),
            mimetype="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AdvancedChatPausedBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
response = {
"event": "message",
"event": StreamEvent.MESSAGE.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
"""
Process generate task pipeline.
:return:
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
    self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
    """
    Build a paused blocking response from collected human-input-required events.

    :param human_input_responses: human input events gathered from the stream; must be
        non-empty, because the last element supplies the workflow run id
    :return: a paused blocking response whose status is ``WorkflowExecutionStatus.PAUSED``
    """
    runtime_state = self._resolve_graph_runtime_state()

    # Dict keys keep first-seen order, so this de-duplicates node ids without reordering them.
    ordered_node_ids: dict = {}
    for item in human_input_responses:
        ordered_node_ids.setdefault(item.data.node_id, None)

    # One JSON-serialisable pause reason per human input event (duplicates are kept here).
    pause_reasons = []
    for item in human_input_responses:
        reason = HumanInputRequiredPauseReasonPayload.from_response_data(item.data)
        pause_reasons.append(reason.model_dump(mode="json"))

    task_id = self._application_generate_entity.task_id
    latest_response = human_input_responses[-1]
    paused_data = AdvancedChatPausedBlockingResponse.Data(
        id=self._message_id,
        mode=self._conversation_mode,
        conversation_id=self._conversation_id,
        message_id=self._message_id,
        workflow_run_id=latest_response.workflow_run_id,
        answer=self._task_state.answer,
        metadata=self._message_end_to_stream_response().metadata,
        created_at=self._message_created_at,
        paused_nodes=list(ordered_node_ids),
        reasons=pause_reasons,
        status=WorkflowExecutionStatus.PAUSED,
        # Elapsed time is measured from the pipeline's start marker to now.
        elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
        total_tokens=runtime_state.total_tokens,
        total_steps=runtime_state.node_run_steps,
    )
    return AdvancedChatPausedBlockingResponse(task_id=task_id, data=paused_data)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -1,7 +1,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import JsonValue
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
    """
    Narrow a generic blocking response to this converter's blocking response type.

    ``typing.cast`` only informs the type checker: nothing is validated or converted
    at runtime and the very same object is returned.

    :param response: blocking response to narrow
    :return: ``response`` unchanged, typed as ``TBlockingResponse``
    """
    narrowed: TBlockingResponse = cast(TBlockingResponse, response)
    return narrowed
@classmethod
def convert(
@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data: dict[str, JsonValue] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
responses: list[StreamResponse] = []

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
response: dict[str, Any] = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, JsonValue] = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@ -1,6 +1,7 @@
from collections.abc import Callable, Generator, Mapping
from collections.abc import Callable, Generator, Iterable, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@ -26,6 +27,7 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@ -33,4 +35,5 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
if terminal_events is None:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@ -25,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -9,24 +9,29 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -42,12 +42,15 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
"""
Process generate task pipeline.
:return:
@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
data=WorkflowAppPausedBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
return WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
    self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
    """
    Build a paused blocking response from collected human-input-required events.

    Run metrics (tokens, steps, start time) are read from the graph runtime
    state; the run id is taken from the most recent human-input event.

    :param human_input_responses: human-input events in arrival order; must be
        non-empty because the last element is indexed
    :return: paused blocking response with status ``PAUSED`` and empty outputs
    """
    graph_state = self._resolve_graph_runtime_state()

    # dict.fromkeys() de-duplicates node ids while keeping first-seen order.
    unique_node_ids = list(dict.fromkeys(event.data.node_id for event in human_input_responses))
    started_at = int(graph_state.start_at)

    # One JSON-serialisable pause reason per event (duplicates are not collapsed here).
    pause_reasons = []
    for event in human_input_responses:
        payload = HumanInputRequiredPauseReasonPayload.from_response_data(event.data)
        pause_reasons.append(payload.model_dump(mode="json"))

    run_id = human_input_responses[-1].workflow_run_id
    paused_data = WorkflowAppPausedBlockingResponse.Data(
        id=run_id,
        workflow_id=self._workflow.id,
        status=WorkflowExecutionStatus.PAUSED,
        outputs={},
        error=None,
        elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
        total_tokens=graph_state.total_tokens,
        total_steps=graph_state.node_run_steps,
        created_at=started_at,
        finished_at=None,
        paused_nodes=unique_node_ids,
        reasons=pause_reasons,
    )
    return WorkflowAppPausedBlockingResponse(
        task_id=self._application_generate_entity.task_id,
        workflow_run_id=run_id,
        data=paused_data,
    )
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -1,12 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
    """
    Public pause-reason payload used by blocking responses when only
    ``human_input_required`` events are available.

    Every field except ``TYPE`` is copied from ``HumanInputRequiredResponse.Data``
    by :meth:`from_response_data`; ``TYPE`` is a fixed discriminator.
    """

    # Discriminator: always the human-input-required pause reason type.
    TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
    form_id: str
    node_id: str
    node_title: str
    form_content: str
    inputs: Sequence[FormInput] = Field(default_factory=list)
    actions: Sequence[UserAction] = Field(default_factory=list)
    display_in_ui: bool = False
    # Optional; defaults to None when the event carries no token.
    form_token: str | None = None
    resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
    # Required; presumably an epoch timestamp -- TODO confirm unit against the producer.
    expiration_time: int

    @classmethod
    def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
        """
        Build a payload by copying each field from a human-input-required event.

        :param data: the ``data`` section of a ``HumanInputRequiredResponse``
        :return: payload with the same field values and the fixed ``TYPE``
        """
        return cls(
            form_id=data.form_id,
            node_id=data.node_id,
            node_title=data.node_title,
            form_content=data.form_content,
            inputs=data.inputs,
            actions=data.actions,
            display_in_ui=data.display_in_ui,
            form_token=data.form_token,
            resolved_default_values=data.resolved_default_values,
            expiration_time=data.expiration_time,
        )
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
return {
"event": self.event.value,
"task_id": self.task_id,
@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
    """
    AdvancedChatPausedBlockingResponse entity

    Blocking response that carries chat message fields together with pause
    details (``paused_nodes``, ``reasons``) and run metrics.
    """

    class Data(BaseModel):
        """
        Data entity
        """

        id: str
        mode: str
        conversation_id: str
        message_id: str
        workflow_run_id: str
        answer: str
        metadata: Mapping[str, object] = Field(default_factory=dict)
        created_at: int
        # Both pause fields default to empty lists when not supplied.
        paused_nodes: Sequence[str] = Field(default_factory=list)
        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
        status: WorkflowExecutionStatus
        elapsed_time: float
        total_tokens: int
        total_steps: int

    data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
    """
    WorkflowAppPausedBlockingResponse entity

    Blocking response for a workflow run that has paused; in addition to run
    metrics it lists the paused nodes and the reasons for the pause.
    """

    class Data(BaseModel):
        """
        Data entity
        """

        id: str
        workflow_id: str
        status: WorkflowExecutionStatus
        outputs: Mapping[str, Any] | None = None
        error: str | None = None
        elapsed_time: float
        total_tokens: int
        total_steps: int
        created_at: int
        # Nullable but has no default, so it must always be passed explicitly.
        finished_at: int | None
        # Both pause fields default to empty lists when not supplied.
        paused_nodes: Sequence[str] = Field(default_factory=list)
        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)

    workflow_run_id: str
    data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -0,0 +1,41 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
    """
    Upload a DSL file to the Creators Platform anonymous-upload endpoint.

    :param dsl_file_bytes: raw content of the DSL file
    :param filename: file name sent in the multipart ``file`` part
    :return: the ``claim_code`` found under the response's ``data`` object
    :raises httpx.HTTPStatusError: if the platform answers with an error status
    :raises ValueError: if the response carries no usable ``claim_code``
    """
    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
    response.raise_for_status()

    payload = response.json()
    # The body may be a non-object, or "data" may be null / not an object; treat all
    # of those as "no claim code" instead of failing with AttributeError.
    data_section = payload.get("data") if isinstance(payload, dict) else None
    claim_code = data_section.get("claim_code") if isinstance(data_section, dict) else None
    if not claim_code:
        raise ValueError("Creators Platform did not return a valid claim_code")
    return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
    """
    Build the Creators Platform URL that carries the DSL claim code.

    When ``CREATORS_PLATFORM_OAUTH_CLIENT_ID`` is configured, an OAuth
    authorization code signed for ``user_account_id`` is added as ``oauth_code``.

    :param user_account_id: account the OAuth authorization code is signed for
    :param claim_code: claim code returned by ``upload_dsl``
    :return: platform base URL (trailing slashes removed) plus the query string
    """
    platform_root = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
    query: dict[str, str] = {"dsl_claim_code": claim_code}

    oauth_client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
    if oauth_client_id:
        # Imported lazily, only when an OAuth client is configured.
        from services.oauth_server import OAuthServerService

        query["oauth_code"] = OAuthServerService.sign_oauth_authorization_code(oauth_client_id, user_account_id)

    return platform_root + "?" + urlencode(query)

View File

@ -13,8 +13,6 @@ from core.llm_generator.output_parser.rule_config_generator import RuleConfigGen
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
@ -217,8 +215,8 @@ class LLMGenerator:
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE,
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []

View File

@ -10,7 +10,14 @@ logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = instruction_prompt or DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt

View File

@ -104,9 +104,6 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
DEFAULT_SUGGESTED_QUESTIONS_MAX_TOKENS = 256
DEFAULT_SUGGESTED_QUESTIONS_TEMPERATURE = 0.0
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
topK=max_keywords_per_chunk or 10,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)

View File

@ -551,6 +551,7 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

View File

@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore:
@ -127,6 +128,7 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -137,7 +139,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)
@ -163,6 +165,7 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -173,7 +176,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)

View File

@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: Parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -12,20 +12,16 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType
_FORM_TOKEN_PRIORITY = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
return {}
if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids)
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
tokens_by_form_id: dict[str, tuple[int, str]] = {}
def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
if priority is None or not recipient.access_token:
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
candidate = (priority, recipient.access_token)
current = tokens_by_form_id.get(recipient.form_id)
if current is None or candidate[0] < current[0]:
tokens_by_form_id[recipient.form_id] = candidate
tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
def _get_surface_form_token(
    recipients: Sequence[tuple[RecipientType, str]],
    *,
    surface: HumanInputSurface | None,
) -> str | None:
    """
    Choose the form token to expose for one form on the given surface.

    For the Service API surface the first non-empty standalone-web-app token
    wins. Every other case, including a Service API form with no such token,
    defers to ``get_preferred_form_token``.

    :param recipients: ``(recipient_type, access_token)`` pairs of a single form
    :param surface: calling surface, or ``None`` for the shared priority order
    :return: the chosen token, or ``None`` when no recipient qualifies
    """
    if surface == HumanInputSurface.SERVICE_API:
        web_app_token = next(
            (
                access_token
                for kind, access_token in recipients
                if kind == RecipientType.STANDALONE_WEB_APP and access_token
            ),
            None,
        )
        if web_app_token is not None:
            return web_app_token
    # NOTE(review): for SERVICE_API this fallback can return a console/backstage
    # token when no web-app recipient exists -- confirm that is intended.
    return get_preferred_form_token(recipients)

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
    """API surface through which a human-input form is accessed."""

    SERVICE_API = "service_api"
    CONSOLE = "console"
# Service API is intentionally narrower than other surfaces: app-token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
}
# A single HITL form can have multiple recipient records; this shared priority
# keeps every API surface consistent about which resume token to expose.
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
RecipientType.BACKSTAGE: 0,
RecipientType.CONSOLE: 1,
RecipientType.STANDALONE_WEB_APP: 2,
}
def is_recipient_type_allowed_for_surface(
    recipient_type: RecipientType | None,
    surface: HumanInputSurface,
) -> bool:
    """
    Tell whether a recipient type may be used from the given surface.

    :param recipient_type: recipient type to check; ``None`` is never allowed
    :param surface: surface whose allow-list is consulted
    :return: ``True`` only when the type is in the surface's allow-list
    :raises KeyError: if ``surface`` has no entry in the allow-list table
    """
    # Short-circuit: a missing recipient type is rejected before the table lookup.
    return recipient_type is not None and recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
def get_preferred_form_token(
    recipients: Sequence[tuple[RecipientType, str]],
) -> str | None:
    """
    Return the token of the highest-priority recipient.

    Priorities come from ``_RECIPIENT_TOKEN_PRIORITY`` (lower number wins).
    Recipients with an unlisted type or an empty token are ignored; on equal
    priority the earliest recipient is kept.

    :param recipients: ``(recipient_type, access_token)`` pairs of a single form
    :return: the selected token, or ``None`` when no recipient qualifies
    """
    ranked = [
        (_RECIPIENT_TOKEN_PRIORITY[recipient_type], token)
        for recipient_type, token in recipients
        if recipient_type in _RECIPIENT_TOKEN_PRIORITY and token
    ]
    if not ranked:
        return None
    # min() returns the first minimal entry, so ties keep the earliest recipient.
    _, best_token = min(ranked, key=lambda entry: entry[0])
    return best_token
def enrich_human_input_pause_reasons(
    reasons: Sequence[Mapping[str, Any]],
    *,
    form_tokens_by_form_id: Mapping[str, str],
    expiration_times_by_form_id: Mapping[str, int],
) -> list[dict[str, Any]]:
    """
    Return copies of the pause reasons with form token and expiry filled in.

    Only reasons whose ``TYPE`` is human-input-required and whose ``form_id`` is
    a string are touched. For those, ``form_token`` is always overwritten (with
    ``None`` when the form has no token) and ``expiration_time`` is overwritten
    only when a value is known. The input mappings are not mutated.

    :param reasons: pause reason mappings to enrich
    :param form_tokens_by_form_id: token to expose for each form id
    :param expiration_times_by_form_id: expiration time for each form id
    :return: new list of shallow-copied reason dicts, in input order
    """

    def _enrich_one(reason: Mapping[str, Any]) -> dict[str, Any]:
        # Work on a shallow copy so the caller's mapping stays untouched.
        enriched_reason = dict(reason)
        is_human_input = enriched_reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED
        if not is_human_input:
            return enriched_reason

        form_id = enriched_reason.get("form_id")
        if not isinstance(form_id, str):
            return enriched_reason

        enriched_reason["form_token"] = form_tokens_by_form_id.get(form_id)
        expires_at = expiration_times_by_form_id.get(form_id)
        if expires_at is not None:
            enriched_reason["expiration_time"] = expires_at
        return enriched_reason

    return [_enrich_one(reason) for reason in reasons]

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
    """
    Look up how much quota the tenant has left for this quota type.

    Args:
        tenant_id: The tenant identifier

    Returns:
        The "remaining" entry (0 when absent) if the billing service answers
        with a dict; otherwise the answer converted with int(), or 0 when
        the answer is falsy. -1 when the lookup or conversion raises.
    """
    from services.billing_service import BillingService

    try:
        usage = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)

        if not isinstance(usage, dict):
            # A bare value is treated as the remaining amount itself.
            if not usage:
                return 0
            return int(usage)

        # Dict payload: read the 'remaining' entry, defaulting to 0.
        return usage.get("remaining", 0)
    except Exception:
        logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
        return -1
def unlimited() -> QuotaCharge:
    """
    Build the charge object used when no quota is actually consumed.

    The returned charge is marked successful, carries no charge_id, and is
    tagged with the UNLIMITED quota type.
    """
    free_charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
    return free_charge

View File

@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
return attachment_list
class ChildChunk(Base):
class ChildChunk(TypeBase):
__tablename__ = "child_chunks"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -1046,29 +1046,42 @@ class ChildChunk(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255),
nullable=False,
server_default=sa.text("'automatic'"),
default=SegmentType.AUTOMATIC,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)
@property
def dataset(self):

View File

@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
    trace_id: int = 1,
    workflow_span_id: int = 2,
    session_id: str = "s",
    user_id: str = "u",
    links: list[Link] | None = None,
) -> TraceMetadata:
    """Build a ``TraceMetadata`` for tests; ``links`` defaults to a fresh empty list."""
    # Resolve the default here so each call gets its own list object.
    resolved_links = links if links is not None else []
    return TraceMetadata(
        trace_id=trace_id,
        workflow_span_id=workflow_span_id,
        session_id=session_id,
        user_id=user_id,
        links=resolved_links,
    )
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
    """Return ``trace_instance.trace_client`` typed as the ``RecordingTraceClient`` test double."""
    client = trace_instance.trace_client
    return cast(RecordingTraceClient, client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
    """Return the spans recorded by the instance's trace client, typed as ``SpanData``."""
    recorded = _recording_trace_client(trace_instance).added_spans
    return cast(list[SpanData], recorded)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

View File

@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr"
span.reset_mock()
set_span_status(span, SilentError())
set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
LangfuseConfig.model_validate({})
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self):
"""Test host validation with empty value"""

View File

@ -2,6 +2,7 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig
@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
LangSmithConfig.model_validate({})
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once()
@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info)
@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png")
@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info)
@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock()
end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"},

View File

@ -12,6 +12,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
    """Return ``instance.add_trace`` typed as a ``MagicMock`` so call records can be inspected."""
    add_trace = instance.add_trace
    return cast(MagicMock, add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
    """Return ``instance.add_span`` typed as a ``MagicMock`` so call records can be inspected."""
    add_span = instance.add_span
    return cast(MagicMock, add_span)
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self):
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self):
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self):
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self):
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self):
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self):
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock
import pytest
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
    """Typed shape of the mapping returned by the ``patch_core_components`` fixture."""

    # Entries typed as MagicMock.
    span_exporter: MagicMock
    span_processor: MagicMock
    tracer: MagicMock
    span: MagicMock
    tracer_provider: MagicMock
    logger: MagicMock

    # Left as Any; tests call attributes such as NonRecordingSpan on it.
    trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
    """Build a non-remote ``SpanContext`` with the SAMPLED trace flag for the given ids."""
    sampled_flags = TraceFlags(TraceFlags.SAMPLED)
    return SpanContext(trace_id=trace_id, span_id=span_id, is_remote=False, trace_flags=sampled_flags)
def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData(
trace_id=1,
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData(
trace_id=1,
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span(
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown()
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct(
trace_id=1,
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
from packaging import version
from pydantic import BaseModel, model_validator
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False
try:
milvus_version = self._client.get_server_version()
milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True

View File

@ -3,7 +3,7 @@ import json
import logging
import re
import uuid
from typing import Any
from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore
import numpy
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel):
user: str
password: str
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 5,
"increment": 1,
}
pool_params = _OraclePoolParams(
user=config.user,
password=config.password,
dsn=config.dsn,
min=1,
max=5,
increment=1,
)
if config.is_autonomous:
pool_params.update(
{
"config_dir": config.config_dir,
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
pool_params["config_dir"] = config.config_dir
pool_params["wallet_location"] = config.wallet_location
pool_params["wallet_password"] = config.wallet_password
return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):

View File

@ -9,6 +9,7 @@ dependencies = [
"boto3>=1.42.91",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
"flask-cors>=6.0.2",
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
@ -117,7 +118,7 @@ dev = [
"faker>=40.15.0",
"lxml-stubs>=0.5.1",
"basedpyright>=1.39.3",
"ruff>=0.15.11",
"ruff>=0.15.12",
"pytest>=9.0.3",
"pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0",
@ -144,7 +145,7 @@ dev = [
"types-pexpect>=4.9.0",
"types-protobuf>=7.34.1",
"types-psutil>=7.2.2",
"types-psycopg2>=2.9.21",
"types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0",
"types-pymysql>=1.1.0",
"types-python-dateutil>=2.9.0",
@ -157,9 +158,9 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.92",
"boto3-stubs>=1.42.96",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.1",
"hypothesis>=6.152.3",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
@ -169,7 +170,7 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy>=1.20.1",
"mypy>=1.20.2",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",

View File

@ -42,7 +42,7 @@ from libs.helper import convert_datetime_to_date
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from models.enums import WorkflowRunTriggeredFrom
from models.human_input import HumanInputForm
from models.human_input import HumanInputForm, HumanInputFormRecipient
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.entities.workflow_pause import WorkflowPauseEntity
@ -63,6 +63,7 @@ class _WorkflowRunError(Exception):
def _build_human_input_required_reason(
reason_model: WorkflowPauseReason,
form_model: HumanInputForm | None,
recipients: Sequence[HumanInputFormRecipient] = (),
) -> HumanInputRequired:
form_content = ""
inputs = []
@ -89,7 +90,7 @@ def _build_human_input_required_reason(
resolved_default_values = dict(definition.default_values)
node_title = definition.node_title or node_title
return HumanInputRequired(
reason = HumanInputRequired(
form_id=form_id,
form_content=form_content,
inputs=inputs,
@ -98,6 +99,7 @@ def _build_human_input_required_reason(
node_title=node_title,
resolved_default_values=resolved_default_values,
)
return reason
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
@ -804,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
for form in session.scalars(form_stmt).all():
form_models[form.id] = form
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {}
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(recipient_stmt).all():
recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient)
pause_reasons: list[PauseReason] = []
for reason in pause_reason_models:
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
form_model = form_models.get(reason.form_id)
pause_reasons.append(_build_human_input_required_reason(reason, form_model))
pause_reasons.append(
_build_human_input_required_reason(
reason,
form_model,
recipients_by_form_id.get(reason.form_id, ()),
)
)
else:
pause_reasons.append(reason.to_entity())
return pause_reasons

View File

@ -133,7 +133,14 @@ class AppAnnotationService:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=None,
message_id=None,
content=answer,
question=question,
account_id=current_user.id,
)
db.session.add(annotation)
db.session.commit()

View File

@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
@ -106,7 +107,7 @@ class AppGenerateService:
quota_charge = unlimited()
if dify_config.BILLING_ENABLED:
try:
quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
except QuotaExceededError:
raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")
@ -116,6 +117,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
@ -162,6 +164,7 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=True,
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
@ -183,6 +186,10 @@ class AppGenerateService:
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
@ -194,6 +201,7 @@ class AppGenerateService:
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
pause_state_config=pause_config,
)
),
request_id=request_id,

View File

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -88,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit the read-only session before starting the billing RPC call
session.commit()
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -131,9 +135,10 @@ class AsyncWorkflowService:
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume quota
# 7. Reserve quota (commit after successful dispatch)
quota_charge = unlimited()
try:
QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
except QuotaExceededError as e:
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +158,18 @@ class AsyncWorkflowService:
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
try:
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict)
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict)
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
@ -295,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
"""
Get workflow for the app
Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns:
Workflow instance
@ -311,12 +329,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class QuotaReserveResult(TypedDict):
    """Validated response body of ``POST /quota/reserve``."""

    # Identifier later passed to the commit / release calls.
    reservation_id: str
    # Counters reported by billing alongside the reservation.
    available: int
    reserved: int
class QuotaCommitResult(TypedDict):
    """Validated response body of ``POST /quota/commit``."""

    available: int
    reserved: int
    # NOTE(review): presumably the part of the reservation handed back when the
    # committed amount is lower than the reserved one - confirm with billing.
    refunded: int
class QuotaReleaseResult(TypedDict):
    """Validated response body of ``POST /quota/release``."""

    available: int
    reserved: int
    # NOTE(review): presumably the amount freed by this release - confirm with billing.
    released: int
# Module-level adapters, built once, that validate the raw payloads returned by
# the /quota/reserve, /quota/commit and /quota/release billing endpoints.
_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
    """Response of /quota/info.

    NOTE (hj24):
    - Follows the BillingInfo convention: billing may send int fields as str,
      so validation must stay in non-strict mode and let them be coerced.
    """

    # One quota entry per feature key returned by billing.
    trigger_event: _TenantFeatureQuota
    api_rate_limit: _TenantFeatureQuota
# Module-level adapter, built once, that validates the raw /quota/info payload.
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
@ -149,11 +193,63 @@ class BillingService:
@classmethod
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
"""Deprecated: Use get_quota_info instead."""
params = {"tenant_id": tenant_id}
usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
    """Fetch a tenant's per-feature quota usage from ``GET /quota/info``.

    Args:
        tenant_id: Tenant whose quota information is requested.

    Returns:
        The billing response validated as ``TenantFeatureQuotaInfo``.
    """
    raw_info = cls._send_request("GET", "/quota/info", params={"tenant_id": tenant_id})
    return _tenant_feature_quota_info_adapter.validate_python(raw_info)
@classmethod
def quota_reserve(
    cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
) -> QuotaReserveResult:
    """Reserve quota before task execution via ``POST /quota/reserve``.

    Args:
        tenant_id: Tenant the quota belongs to.
        feature_key: Billing key of the feature being charged.
        request_id: Caller-generated identifier sent along with the request.
        amount: Amount of quota to reserve.
        meta: Optional extra data; only sent when it is non-empty.

    Returns:
        The billing response validated as ``QuotaReserveResult``.
    """
    # An empty or missing ``meta`` is left out of the request body entirely.
    optional_part = {"meta": meta} if meta else {}
    body: dict = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "request_id": request_id,
        "amount": amount,
        **optional_part,
    }
    response = cls._send_request("POST", "/quota/reserve", json=body)
    return _quota_reserve_adapter.validate_python(response)
@classmethod
def quota_commit(
    cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
) -> QuotaCommitResult:
    """Commit a reservation with the actual consumption via ``POST /quota/commit``.

    Args:
        tenant_id: Tenant the reservation belongs to.
        feature_key: Billing key of the feature being charged.
        reservation_id: Identifier returned by the reserve call.
        actual_amount: Amount actually consumed.
        meta: Optional extra data; only sent when it is non-empty.

    Returns:
        The billing response validated as ``QuotaCommitResult``.
    """
    # An empty or missing ``meta`` is left out of the request body entirely.
    optional_part = {"meta": meta} if meta else {}
    body: dict = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "reservation_id": reservation_id,
        "actual_amount": actual_amount,
        **optional_part,
    }
    response = cls._send_request("POST", "/quota/commit", json=body)
    return _quota_commit_adapter.validate_python(response)
@classmethod
def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
    """Release a reservation (cancel it and return the frozen quota).

    Args:
        tenant_id: Tenant the reservation belongs to.
        feature_key: Billing key of the feature that was reserved.
        reservation_id: Identifier returned by the reserve call.

    Returns:
        The billing response validated as ``QuotaReleaseResult``.
    """
    body = {
        "tenant_id": tenant_id,
        "feature_key": feature_key,
        "reservation_id": reservation_id,
    }
    response = cls._send_request("POST", "/quota/release", json=body)
    return _quota_release_adapter.validate_python(response)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}

View File

@ -3748,6 +3748,7 @@ class SegmentService:
ChildChunk.segment_id == segment.id,
)
)
assert current_user.current_tenant_id
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
@ -3758,7 +3759,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=content,
word_count=len(content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)
db.session.add(child_chunk)
@ -3818,6 +3819,7 @@ class SegmentService:
if new_child_chunks_args:
child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
assert current_user.current_tenant_id
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content)
child_chunk = ChildChunk(
@ -3830,7 +3832,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=args.content,
word_count=len(args.content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)

View File

@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
    """Return the enterprise ``/info`` payload.

    The call is wrapped in ``ttl_cache(ttl=5)``, so repeated calls inside the
    TTL window reuse the cached response instead of sending another request.
    """
    info = EnterpriseRequest.send_request("GET", "/info")
    return info

View File

@ -177,6 +177,7 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
@ -241,6 +242,9 @@ class FeatureService:
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
system_features.enable_creators_platform = True
return system_features
@classmethod
@ -286,7 +290,7 @@ class FeatureService:
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
features_usage_info = BillingService.get_quota_info(tenant_id)
features.billing.enabled = billing_info["enabled"]
features.billing.subscription.plan = billing_info["subscription"]["plan"]

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from enums.quota_type import QuotaType
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
    """
    Handle for a quota reservation (the Reserve phase).

    Two-phase lifecycle:
        charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except:
            charge.refund()  # Release frozen quota

    ``QuotaService.consume`` performs the reserve and the commit in one step, so
    a charge obtained from it is already committed.

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.
    """

    success: bool
    # Reservation id issued by billing; None when nothing was actually reserved.
    charge_id: str | None
    _quota_type: QuotaType
    _tenant_id: str | None = None
    _feature_key: str | None = None
    # Amount that was reserved; used as the default amount on commit.
    _amount: int = 0
    # Set once a commit call has succeeded; kept out of repr().
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with the actual amount.

        Does nothing when the charge is already committed or carries no
        reservation. Failures are logged and swallowed, never raised.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically.
        """
        if self._committed or not (self.charge_id and self._tenant_id and self._feature_key):
            return

        try:
            # Imported lazily inside the try, so an import failure is logged too.
            from services.billing_service import BillingService

            amount = self._amount if actual_amount is None else actual_amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
        - the charge failed or billing was disabled (charge_id is None)
        - it was already committed (Release after Commit is a no-op)
        - it was already refunded (idempotent)

        This method guarantees no exceptions will be raised.
        """
        if self.charge_id and self._tenant_id and self._feature_key:
            QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)
def unlimited() -> QuotaCharge:
    """Return a successful charge with no reservation behind it.

    ``charge_id`` is None, so ``commit()`` and ``refund()`` on the result do nothing.
    """
    from enums.quota_type import QuotaType

    no_op_charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
    return no_op_charge
class QuotaService:
"""Orchestrates quota reserve / commit / release lifecycle via BillingService."""
@staticmethod
def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
    """
    One-shot mode: reserve, then commit straight away.

    The returned QuotaCharge still supports .refund(), which calls Release.
    For two-phase usage (e.g. streaming), call reserve() directly.
    """
    reservation = QuotaService.reserve(quota_type, tenant_id, amount)
    # Only a real reservation (one that carries an id) has anything to commit.
    has_reservation = reservation.success and reservation.charge_id
    if has_reservation:
        reservation.commit()
    return reservation
@staticmethod
def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
    """
    Reserve quota before task execution (Reserve phase only).

    The caller MUST call charge.commit() after the task succeeds,
    or charge.refund() if the task fails.

    With billing disabled, a successful charge without a reservation is
    returned and ``amount`` is not validated. Unexpected errors from the
    billing call are logged and answered with ``unlimited()``.

    Raises:
        QuotaExceededError: When quota is insufficient (billing returned no
            reservation id).
        ValueError: When ``amount`` is not positive; a ValueError raised by the
            billing call is propagated as well.
    """
    from services.billing_service import BillingService
    from services.errors.app import QuotaExceededError

    if not dify_config.BILLING_ENABLED:
        logger.debug("Billing disabled, allowing request for %s", tenant_id)
        return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)

    logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)

    if amount <= 0:
        raise ValueError("Amount to reserve must be greater than 0")

    req_id = str(uuid.uuid4())
    billing_key = quota_type.billing_key

    try:
        response = BillingService.quota_reserve(
            tenant_id=tenant_id,
            feature_key=billing_key,
            request_id=req_id,
            amount=amount,
        )

        reservation_id = response.get("reservation_id")
        if reservation_id:
            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=billing_key,
                _amount=amount,
            )

        # A response without a reservation id is treated as "quota exceeded".
        logger.warning(
            "Reserve returned no reservation_id for %s, feature %s, response: %s",
            tenant_id,
            quota_type.value,
            reserve_resp := response,
        )
        raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)
    except (QuotaExceededError, ValueError):
        # These two are part of the contract and must reach the caller.
        raise
    except Exception:
        # Anything else is logged and the request is allowed through.
        logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
        return unlimited()
@staticmethod
def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = QuotaService.get_remaining(quota_type, tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
return True
@staticmethod
def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
"""Release a reservation. Guarantees no exceptions."""
try:
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not reservation_id:
return
logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
BillingService.quota_release(
tenant_id=tenant_id,
feature_key=feature_key,
reservation_id=reservation_id,
)
except Exception:
logger.exception("Failed to release quota, reservation_id: %s", reservation_id)
@staticmethod
def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
from services.billing_service import BillingService
try:
usage_info = BillingService.get_quota_info(tenant_id)
if isinstance(usage_info, dict):
feature_info = usage_info.get(quota_type.billing_key, {})
if isinstance(feature_info, dict):
limit = feature_info.get("limit", 0)
usage = feature_info.get("usage", 0)
if limit == -1:
return -1
return max(0, limit - usage)
return 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
return -1

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@ -521,7 +521,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@ -635,7 +635,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
@ -798,45 +799,47 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# consume quota before triggering workflow execution
try:
QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
try:
# NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()` here;
# trigger_workflow_async needs to handle multiple session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
raise
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -16,6 +16,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.enums import SegmentType
logger = logging.getLogger(__name__)
@ -178,7 +179,7 @@ class VectorService:
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
@ -222,6 +223,7 @@ class VectorService:
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
assert update_child_chunk.index_node_id
child_document = Document(
page_content=update_child_chunk.content,
metadata={
@ -234,6 +236,7 @@ class VectorService:
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
assert delete_child_chunk.index_node_id
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
@ -246,6 +249,7 @@ class VectorService:
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
    """
    Delete one child chunk from the dataset's vector index.

    The chunk is identified by its index_node_id, which must be set;
    otherwise an AssertionError is raised.
    """
    vector_index = Vector(dataset=dataset)
    node_id = child_chunk.index_node_id
    # A chunk without an index node id cannot be addressed in the index.
    assert node_id
    vector_index.delete_by_ids([node_id])
@classmethod

Some files were not shown because too many files have changed in this diff Show More