mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 14:58:23 +08:00
Merge remote-tracking branch 'origin/main'
# Conflicts: # api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
This commit is contained in:
commit
3c27a90eb9
@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
|
||||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// vi.mock('@/next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
||||
8
.github/actions/setup-web/action.yml
vendored
8
.github/actions/setup-web/action.yml
vendored
@ -4,10 +4,10 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
|
||||
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
|
||||
with:
|
||||
node-version-file: "./web/.nvmrc"
|
||||
node-version-file: web/.nvmrc
|
||||
cache: true
|
||||
cache-dependency-path: web/pnpm-lock.yaml
|
||||
run-install: |
|
||||
- cwd: ./web
|
||||
args: ['--frozen-lockfile']
|
||||
cwd: ./web
|
||||
|
||||
2
.github/workflows/anti-slop.yml
vendored
2
.github/workflows/anti-slop.yml
vendored
@ -12,7 +12,7 @@ jobs:
|
||||
anti-slop:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: peakoss/anti-slop@v0
|
||||
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
|
||||
38
.github/workflows/api-tests.yml
vendored
38
.github/workflows/api-tests.yml
vendored
@ -2,6 +2,12 @@ name: Run Pytest
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -11,6 +17,8 @@ jobs:
|
||||
test:
|
||||
name: API Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@ -24,10 +32,11 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -79,21 +88,12 @@ jobs:
|
||||
api/tests/test_containers_integration_tests \
|
||||
api/tests/unit_tests
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
{
|
||||
echo ""
|
||||
echo "<details><summary>File-level coverage (click to expand)</summary>"
|
||||
echo ""
|
||||
echo '```'
|
||||
uv run --project api coverage report -m
|
||||
echo '```'
|
||||
echo "</details>"
|
||||
} >> $GITHUB_STEP_SUMMARY
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
flags: api
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
6
.github/workflows/main-ci.yml
vendored
6
.github/workflows/main-ci.yml
vendored
@ -56,16 +56,14 @@ jobs:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
with:
|
||||
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
|
||||
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
|
||||
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -120,7 +120,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
|
||||
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
65
.github/workflows/web-tests.yml
vendored
65
.github/workflows/web-tests.yml
vendored
@ -2,16 +2,9 @@ name: Web Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
base_sha:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
type: string
|
||||
diff_range_mode:
|
||||
required: false
|
||||
type: string
|
||||
head_sha:
|
||||
required: false
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -63,7 +56,7 @@ jobs:
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@ -87,52 +80,16 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
run: vp test --merge-reports --coverage --silent=passed-only
|
||||
|
||||
- name: Report app/components baseline coverage
|
||||
run: node ./scripts/report-components-coverage-baseline.mjs
|
||||
|
||||
- name: Report app/components test touch
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/report-components-test-touch.mjs
|
||||
|
||||
- name: Check app/components pure diff coverage
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Check Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
set -eo pipefail
|
||||
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
retention-days: 30
|
||||
if-no-files-found: error
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(ToolOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(ToolOAuthSystemClient).where(
|
||||
ToolOAuthSystemClient.provider == provider_name,
|
||||
ToolOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(TriggerOAuthSystemClient).where(
|
||||
TriggerOAuthSystemClient.provider == provider_name,
|
||||
TriggerOAuthSystemClient.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
deleted_count = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(DatasourceOauthParamConfig).where(
|
||||
DatasourceOauthParamConfig.provider == provider_name,
|
||||
DatasourceOauthParamConfig.plugin_id == plugin_id,
|
||||
)
|
||||
),
|
||||
).rowcount
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str):
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
notion_credentials = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
|
||||
).all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str):
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str):
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
firecrawl_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
|
||||
).all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str):
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str):
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
jina_credentials = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
|
||||
).all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str):
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
@ -740,14 +743,17 @@ def migrate_oss(
|
||||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
updated = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
update(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.values(storage_type=dify_config.STORAGE_TYPE)
|
||||
),
|
||||
).rowcount
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -41,7 +42,7 @@ def reset_encrypt_key_pair():
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
tenants = session.scalars(select(Tenant)).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
@ -49,8 +50,8 @@ def reset_encrypt_key_pair():
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@ -93,7 +94,7 @@ def convert_to_agent_apps():
|
||||
app_id = str(i.id)
|
||||
if app_id not in proceeded_app_ids:
|
||||
proceeded_app_ids.append(app_id)
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
if app is not None:
|
||||
apps.append(app)
|
||||
|
||||
@ -108,8 +109,8 @@ def convert_to_agent_apps():
|
||||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
db.session.execute(
|
||||
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
@ -177,7 +178,7 @@ where sites.id is null limit 1000"""
|
||||
continue
|
||||
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id))
|
||||
if not app:
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
@ -41,14 +41,13 @@ def migrate_annotation_vector_database():
|
||||
# get apps info
|
||||
per_page = 50
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
apps = session.scalars(
|
||||
select(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
dataset_collection_binding = session.scalar(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
)
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
@ -205,11 +204,11 @@ def migrate_knowledge_vector_database():
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
dataset_collection_binding = db.session.execute(
|
||||
select(DatasetCollectionBinding).where(
|
||||
DatasetCollectionBinding.id == dataset.collection_binding_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
@ -334,7 +333,7 @@ def add_qdrant_index(field: str):
|
||||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
|
||||
if not bindings:
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
@ -421,10 +420,10 @@ def old_metadata_migration():
|
||||
if field.value == key:
|
||||
break
|
||||
else:
|
||||
dataset_metadata = (
|
||||
db.session.query(DatasetMetadata)
|
||||
dataset_metadata = db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset_metadata:
|
||||
dataset_metadata = DatasetMetadata(
|
||||
@ -436,7 +435,7 @@ def old_metadata_migration():
|
||||
)
|
||||
db.session.add(dataset_metadata)
|
||||
db.session.flush()
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
metadata_id=dataset_metadata.id,
|
||||
@ -445,14 +444,14 @@ def old_metadata_migration():
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
else:
|
||||
dataset_metadata_binding = (
|
||||
db.session.query(DatasetMetadataBinding) # type: ignore
|
||||
dataset_metadata_binding = db.session.scalar(
|
||||
select(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == document.dataset_id,
|
||||
DatasetMetadataBinding.document_id == document.id,
|
||||
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset_metadata_binding:
|
||||
dataset_metadata_binding = DatasetMetadataBinding(
|
||||
|
||||
@ -103,13 +103,13 @@ class AppMCPServerController(Resource):
|
||||
raise NotFound()
|
||||
|
||||
description = payload.description
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
if description is None or not description:
|
||||
server.description = app_model.description or ""
|
||||
else:
|
||||
server.description = description
|
||||
|
||||
server.name = app_model.name
|
||||
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
try:
|
||||
|
||||
@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from flask_restx import Resource
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.enums import BannerStatus
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
@ -16,7 +17,7 @@ class BannerApi(Resource):
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
|
||||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
"""Handle webhook debug calls without triggering production workflow execution.
|
||||
|
||||
The debug webhook endpoint is only for draft inspection flows. It never enqueues
|
||||
Celery work for the published workflow; instead it dispatches an in-memory debug
|
||||
event to an active Variable Inspector listener. Returning a clear error when no
|
||||
listener is registered prevents a misleading 200 response for requests that are
|
||||
effectively dropped.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
|
||||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
dispatch_count = TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
if dispatch_count == 0:
|
||||
logger.warning(
|
||||
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
|
||||
webhook_trigger.webhook_id,
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.app_id,
|
||||
webhook_trigger.node_id,
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "No active debug listener",
|
||||
"message": (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
),
|
||||
"execution_url": webhook_trigger.webhook_url,
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
@ -81,7 +82,7 @@ class DatasourceFileManager:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=filepath,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
|
||||
@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
try:
|
||||
@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="custom_model",
|
||||
credential_source=CredentialSourceType.CUSTOM_MODEL,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
|
||||
@ -1421,12 +1422,12 @@ class ProviderConfiguration(BaseModel):
|
||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
preferred_model_provider.preferred_provider_type = provider_type
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value,
|
||||
preferred_provider_type=provider_type,
|
||||
)
|
||||
s.add(preferred_model_provider)
|
||||
s.commit()
|
||||
@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "custom_model"
|
||||
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
|
||||
custom_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "provider"
|
||||
if config.credential_source_type != CredentialSourceType.PROVIDER
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
||||
@ -195,7 +195,7 @@ class ProviderManager:
|
||||
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
|
||||
@ -68,9 +68,12 @@ class SegmentRecord(TypedDict):
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@ -1,12 +1,38 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any, cast
|
||||
from typing import Any, NotRequired, cast
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class FirecrawlDocumentData(TypedDict):
|
||||
title: str | None
|
||||
description: str | None
|
||||
source_url: str | None
|
||||
markdown: str | None
|
||||
|
||||
|
||||
class CrawlStatusResponse(TypedDict):
|
||||
status: str
|
||||
total: int | None
|
||||
current: int | None
|
||||
data: list[FirecrawlDocumentData]
|
||||
|
||||
|
||||
class MapResponse(TypedDict):
|
||||
success: bool
|
||||
links: list[str]
|
||||
|
||||
|
||||
class SearchResponse(TypedDict):
|
||||
success: bool
|
||||
data: list[dict[str, Any]]
|
||||
warning: NotRequired[str]
|
||||
|
||||
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
@ -14,7 +40,7 @@ class FirecrawlApp:
|
||||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict[str, Any]:
|
||||
def scrape_url(self, url, params=None) -> FirecrawlDocumentData:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
@ -32,9 +58,7 @@ class FirecrawlApp:
|
||||
return self._extract_common_fields(data)
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "scrape URL")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
|
||||
@ -51,7 +75,7 @@ class FirecrawlApp:
|
||||
self._handle_error(response, "start crawl job")
|
||||
return "" # unreachable
|
||||
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
|
||||
headers = self._prepare_headers()
|
||||
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
|
||||
@ -60,14 +84,12 @@ class FirecrawlApp:
|
||||
json_data.update(params)
|
||||
response = self._post_request(self._build_url("v2/map"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
return cast(dict[str, Any], response.json())
|
||||
return cast(MapResponse, response.json())
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "start map job")
|
||||
return {}
|
||||
else:
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
def check_crawl_status(self, job_id) -> CrawlStatusResponse:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
|
||||
if response.status_code == 200:
|
||||
@ -77,7 +99,7 @@ class FirecrawlApp:
|
||||
if total == 0:
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list = []
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = self._extract_common_fields(item)
|
||||
@ -95,13 +117,15 @@ class FirecrawlApp:
|
||||
return self._format_crawl_status_response(
|
||||
crawl_status_response.get("status"), crawl_status_response, []
|
||||
)
|
||||
else:
|
||||
self._handle_error(response, "check crawl status")
|
||||
return {} # unreachable
|
||||
self._handle_error(response, "check crawl status")
|
||||
raise RuntimeError("unreachable: _handle_error always raises")
|
||||
|
||||
def _format_crawl_status_response(
|
||||
self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
self,
|
||||
status: str,
|
||||
crawl_status_response: dict[str, Any],
|
||||
url_data_list: list[FirecrawlDocumentData],
|
||||
) -> CrawlStatusResponse:
|
||||
return {
|
||||
"status": status,
|
||||
"total": crawl_status_response.get("total"),
|
||||
@ -109,7 +133,7 @@ class FirecrawlApp:
|
||||
"data": url_data_list,
|
||||
}
|
||||
|
||||
def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
|
||||
def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData:
|
||||
return {
|
||||
"title": item.get("metadata", {}).get("title"),
|
||||
"description": item.get("metadata", {}).get("description"),
|
||||
@ -117,7 +141,7 @@ class FirecrawlApp:
|
||||
"markdown": item.get("markdown"),
|
||||
}
|
||||
|
||||
def _prepare_headers(self) -> dict[str, Any]:
|
||||
def _prepare_headers(self) -> dict[str, str]:
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
@ -150,10 +174,10 @@ class FirecrawlApp:
|
||||
error_message = response.text or "Unknown error occurred"
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
|
||||
|
||||
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
json_data: dict[str, Any] = {
|
||||
"query": query,
|
||||
"limit": 5,
|
||||
"lang": "en",
|
||||
@ -170,12 +194,10 @@ class FirecrawlApp:
|
||||
json_data.update(params)
|
||||
response = self._post_request(self._build_url("v2/search"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
response_data: SearchResponse = response.json()
|
||||
if not response_data.get("success"):
|
||||
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
|
||||
return cast(dict[str, Any], response_data)
|
||||
return response_data
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "perform search")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
|
||||
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
|
||||
|
||||
@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor):
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=len(img_bytes),
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlAuthenticationError,
|
||||
@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import (
|
||||
)
|
||||
|
||||
|
||||
class SpiderOptions(TypedDict):
|
||||
max_depth: int
|
||||
page_limit: int
|
||||
allowed_domains: list[str]
|
||||
exclude_paths: list[str]
|
||||
include_paths: list[str]
|
||||
|
||||
|
||||
class PageOptions(TypedDict):
|
||||
exclude_tags: list[str]
|
||||
include_tags: list[str]
|
||||
wait_time: int
|
||||
include_html: bool
|
||||
only_main_content: bool
|
||||
include_links: bool
|
||||
timeout: int
|
||||
accept_cookies_selector: str
|
||||
locale: str
|
||||
actions: list[Any]
|
||||
|
||||
|
||||
class BaseAPIClient:
|
||||
def __init__(self, api_key, base_url):
|
||||
self.api_key = api_key
|
||||
@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
||||
def create_crawl_request(
|
||||
self,
|
||||
url: Union[list, str] | None = None,
|
||||
spider_options: dict | None = None,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
spider_options: SpiderOptions | None = None,
|
||||
page_options: PageOptions | None = None,
|
||||
plugin_options: dict[str, Any] | None = None,
|
||||
):
|
||||
data = {
|
||||
# 'urls': url if isinstance(url, list) else [url],
|
||||
@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
||||
def scrape_url(
|
||||
self,
|
||||
url: str,
|
||||
page_options: dict | None = None,
|
||||
plugin_options: dict | None = None,
|
||||
page_options: PageOptions | None = None,
|
||||
plugin_options: dict[str, Any] | None = None,
|
||||
sync: bool = True,
|
||||
prefetched: bool = True,
|
||||
):
|
||||
|
||||
@ -2,16 +2,39 @@ from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
|
||||
|
||||
|
||||
class WatercrawlDocumentData(TypedDict):
|
||||
title: str | None
|
||||
description: str | None
|
||||
source_url: str | None
|
||||
markdown: str | None
|
||||
|
||||
|
||||
class CrawlJobResponse(TypedDict):
|
||||
status: str
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class WatercrawlCrawlStatusResponse(TypedDict):
|
||||
status: str
|
||||
job_id: str | None
|
||||
total: int
|
||||
current: int
|
||||
data: list[WatercrawlDocumentData]
|
||||
time_consuming: float
|
||||
|
||||
|
||||
class WaterCrawlProvider:
|
||||
def __init__(self, api_key, base_url: str | None = None):
|
||||
self.client = WaterCrawlAPIClient(api_key, base_url)
|
||||
|
||||
def crawl_url(self, url, options: dict | Any | None = None):
|
||||
def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse:
|
||||
options = options or {}
|
||||
spider_options = {
|
||||
spider_options: SpiderOptions = {
|
||||
"max_depth": 1,
|
||||
"page_limit": 1,
|
||||
"allowed_domains": [],
|
||||
@ -25,7 +48,7 @@ class WaterCrawlProvider:
|
||||
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
|
||||
|
||||
wait_time = options.get("wait_time", 1000)
|
||||
page_options = {
|
||||
page_options: PageOptions = {
|
||||
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
|
||||
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
|
||||
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
|
||||
@ -41,9 +64,9 @@ class WaterCrawlProvider:
|
||||
|
||||
return {"status": "active", "job_id": result.get("uuid")}
|
||||
|
||||
def get_crawl_status(self, crawl_request_id):
|
||||
def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse:
|
||||
response = self.client.get_crawl_request(crawl_request_id)
|
||||
data = []
|
||||
data: list[WatercrawlDocumentData] = []
|
||||
if response["status"] in ["new", "running"]:
|
||||
status = "active"
|
||||
else:
|
||||
@ -67,7 +90,7 @@ class WaterCrawlProvider:
|
||||
"time_consuming": time_consuming,
|
||||
}
|
||||
|
||||
def get_crawl_url_data(self, job_id, url) -> dict | None:
|
||||
def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None:
|
||||
if not job_id:
|
||||
return self.scrape_url(url)
|
||||
|
||||
@ -82,11 +105,11 @@ class WaterCrawlProvider:
|
||||
|
||||
return None
|
||||
|
||||
def scrape_url(self, url: str):
|
||||
def scrape_url(self, url: str) -> WatercrawlDocumentData:
|
||||
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
|
||||
return self._structure_data(response)
|
||||
|
||||
def _structure_data(self, result_object: dict):
|
||||
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
|
||||
if isinstance(result_object.get("result", {}), str):
|
||||
raise ValueError("Invalid result object. Expected a dictionary.")
|
||||
|
||||
@ -98,7 +121,9 @@ class WaterCrawlProvider:
|
||||
"markdown": result_object.get("result", {}).get("markdown"),
|
||||
}
|
||||
|
||||
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
|
||||
def _get_results(
|
||||
self, crawl_request_id: str, query_params: dict | None = None
|
||||
) -> Generator[WatercrawlDocumentData, None, None]:
|
||||
page = 0
|
||||
page_size = 100
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor):
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor):
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self.tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=0,
|
||||
|
||||
@ -9,6 +9,7 @@ from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -51,7 +52,7 @@ class IndexProcessor:
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
@ -131,7 +132,12 @@ class IndexProcessor:
|
||||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: SummaryIndexSettingDict | None,
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
|
||||
@ -7,10 +7,11 @@ import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
@ -36,6 +37,13 @@ if TYPE_CHECKING:
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class SummaryIndexSettingDict(TypedDict):
|
||||
enable: bool
|
||||
model_name: NotRequired[str]
|
||||
model_provider_name: NotRequired[str]
|
||||
summary_prompt: NotRequired[str]
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
||||
@ -52,7 +60,7 @@ class BaseIndexProcessor(ABC):
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
||||
@ -23,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
@ -279,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
@ -363,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
def generate_summary(
|
||||
tenant_id: str,
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
|
||||
@ -19,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
@ -362,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
||||
@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
@ -245,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
||||
@ -33,7 +33,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -666,7 +666,11 @@ class DatasetRetrieval:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
return []
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model_config: DefaultRetrievalModelDict = (
|
||||
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
|
||||
if dataset.retrieval_model
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
@ -1058,7 +1062,11 @@ class DatasetRetrieval:
|
||||
all_documents.append(document)
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model: DefaultRetrievalModelDict = (
|
||||
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
|
||||
if dataset.retrieval_model
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
@ -1132,7 +1140,7 @@ class DatasetRetrieval:
|
||||
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# get retrieval model config
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -1141,7 +1149,11 @@ class DatasetRetrieval:
|
||||
}
|
||||
|
||||
for dataset in available_datasets:
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model_config: DefaultRetrievalModelDict = (
|
||||
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
|
||||
if dataset.retrieval_model
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
|
||||
@ -2,6 +2,7 @@ import concurrent.futures
|
||||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class SummaryIndex:
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
is_preview: bool,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
) -> None:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import Final
|
||||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
is_preview: bool,
|
||||
batch: Any,
|
||||
chunks: Mapping[str, Any],
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("document_id is required.")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
|
||||
@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
|
||||
# Get trigger data passed when workflow was triggered
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"event_name": self.node_data.event_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
|
||||
@ -245,6 +245,9 @@ _END_STATE = frozenset(
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
|
||||
Values in this enum are persisted as execution metadata and must stay in sync
|
||||
with every node that writes `NodeRunResult.metadata`.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@ -256,9 +256,13 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
@ -24,13 +25,11 @@ def handle(sender, **kwargs):
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
document = db.session.scalar(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not document:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
||||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
||||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
||||
@ -3,6 +3,7 @@ import json
|
||||
import flask_login
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login):
|
||||
if admin_api_key and admin_api_key == auth_token:
|
||||
workspace_id = request.headers.get("X-WORKSPACE-ID")
|
||||
if workspace_id:
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == workspace_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role == "owner")
|
||||
.one_or_none()
|
||||
)
|
||||
).one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).filter_by(id=ta.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == ta.account_id))
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if not end_user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login):
|
||||
decoded = PassportService().verify(auth_token)
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login):
|
||||
server_code = request.view_args.get("server_code") if request.view_args else None
|
||||
if not server_code:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||
if not app_mcp_server:
|
||||
raise NotFound("App MCP server not found.")
|
||||
end_user = (
|
||||
db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1)
|
||||
)
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
|
||||
@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage):
|
||||
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
|
||||
|
||||
if scheme == "fs":
|
||||
root = kwargs.get("root", "storage")
|
||||
root = kwargs.setdefault("root", "storage")
|
||||
Path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
|
||||
|
||||
@ -424,13 +424,11 @@ def _build_from_datasource_file(
|
||||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not datasource_file_id:
|
||||
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
|
||||
datasource_file = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
datasource_file = db.session.scalar(
|
||||
select(UploadFile).where(
|
||||
UploadFile.id == datasource_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
|
||||
@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@ -23,13 +23,22 @@ from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
|
||||
from .enums import (
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
MessageChainType,
|
||||
MessageStatus,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
@ -925,8 +934,11 @@ class ExporleBanner(TypeBase):
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
||||
status: Mapped[BannerStatus] = mapped_column(
|
||||
EnumText(BannerStatus, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'enabled'::character varying"),
|
||||
default=BannerStatus.ENABLED,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
@ -2097,7 +2109,7 @@ class UploadFile(Base):
|
||||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
@ -2141,7 +2153,7 @@ class UploadFile(Base):
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
storage_type: str,
|
||||
storage_type: StorageType,
|
||||
key: str,
|
||||
name: str,
|
||||
size: int,
|
||||
@ -2206,7 +2218,7 @@ class MessageChain(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
|
||||
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -13,6 +13,7 @@ from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@ -209,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
@ -237,7 +238,9 @@ class ProviderOrder(TypeBase):
|
||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[str | None] = mapped_column(String(40))
|
||||
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
|
||||
payment_status: Mapped[PaymentStatus] = mapped_column(
|
||||
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
|
||||
)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
@ -300,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
|
||||
EnumText(CredentialSourceType, length=40), nullable=True, default=None
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
||||
@ -22,14 +22,14 @@ from sqlalchemy import (
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
|
||||
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
|
||||
elif (
|
||||
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
|
||||
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
|
||||
):
|
||||
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.1"
|
||||
version = "1.13.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[pytest]
|
||||
pythonpath = .
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||
|
||||
@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from core.helper.marketplace import fetch_global_plugin_manifest
|
||||
@ -28,17 +29,15 @@ def check_upgradable_plugin_task():
|
||||
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
|
||||
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(
|
||||
strategies = db.session.scalars(
|
||||
select(TenantPluginAutoUpgradeStrategy).where(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
TenantPluginAutoUpgradeStrategy.strategy_setting
|
||||
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
total_strategies = len(strategies)
|
||||
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))
|
||||
|
||||
@ -2,7 +2,7 @@ import datetime
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
@ -19,14 +19,12 @@ def clean_embedding_cache_task():
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
while True:
|
||||
try:
|
||||
embedding_ids = (
|
||||
db.session.query(Embedding.id)
|
||||
embedding_ids = db.session.scalars(
|
||||
select(Embedding.id)
|
||||
.where(Embedding.created_at < thirty_days_ago)
|
||||
.order_by(Embedding.created_at.desc())
|
||||
.limit(100)
|
||||
.all()
|
||||
)
|
||||
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
|
||||
).all()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
if embedding_ids:
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from typing import TypedDict
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
@ -51,7 +51,7 @@ def clean_unused_datasets_task():
|
||||
try:
|
||||
# Subquery for counting new documents
|
||||
document_subquery_new = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
@ -64,7 +64,7 @@ def clean_unused_datasets_task():
|
||||
|
||||
# Subquery for counting old documents
|
||||
document_subquery_old = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
@ -142,8 +142,8 @@ def clean_unused_datasets_task():
|
||||
index_processor.clean(dataset, None)
|
||||
|
||||
# Update document
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
|
||||
{Document.enabled: False}
|
||||
db.session.execute(
|
||||
update(Document).where(Document.dataset_id == dataset.id).values(enabled=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
@ -20,7 +21,7 @@ def create_tidb_serverless_task():
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
idle_tidb_serverless_number = (
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
|
||||
db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0
|
||||
)
|
||||
if idle_tidb_serverless_number >= tidb_serverless_number:
|
||||
break
|
||||
|
||||
@ -49,16 +49,18 @@ def mail_clean_document_notify_task():
|
||||
if plan != CloudPlan.SANDBOX:
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
|
||||
if not account:
|
||||
continue
|
||||
|
||||
@ -71,7 +73,7 @@ def mail_clean_document_notify_task():
|
||||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
|
||||
@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account
|
||||
@ -93,7 +94,7 @@ class FileService:
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_tenant_id or "",
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=filename,
|
||||
size=file_size,
|
||||
@ -152,7 +153,7 @@ class FileService:
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=text_name,
|
||||
size=len(text),
|
||||
|
||||
@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
credential_source_type = CredentialSourceType.PROVIDER
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
credential_source = (
|
||||
CredentialSourceType.PROVIDER
|
||||
if config_from == "predefined-model"
|
||||
else CredentialSourceType.CUSTOM_MODEL
|
||||
)
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -49,7 +49,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"fetch pipeline template detail failed,"
|
||||
"fetch pipeline template detail failed,"
|
||||
+ f" status_code: {response.status_code},"
|
||||
+ f" response: {response.text[:1000]}"
|
||||
)
|
||||
|
||||
@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
@ -30,7 +31,7 @@ class SummaryIndexService:
|
||||
def generate_summary_for_segment(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for a single segment.
|
||||
@ -600,7 +601,7 @@ class SummaryIndexService:
|
||||
def generate_and_vectorize_summary(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Generate summary for a segment and vectorize it.
|
||||
@ -705,7 +706,7 @@ class SummaryIndexService:
|
||||
def generate_summaries_for_document(
|
||||
dataset: Dataset,
|
||||
document: DatasetDocument,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
segment_ids: list[str] | None = None,
|
||||
only_parent_chunks: bool = False,
|
||||
) -> list[DocumentSegmentSummary]:
|
||||
|
||||
@ -9,7 +9,7 @@ import httpx
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData
|
||||
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -216,8 +216,10 @@ class WebsiteService:
|
||||
"max_depth": request.options.max_depth,
|
||||
"use_sitemap": request.options.use_sitemap,
|
||||
}
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
return dict(
|
||||
WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -270,13 +272,13 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
"status": result.get("status", "active"),
|
||||
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data: dict[str, Any] = {
|
||||
"status": result["status"],
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
"total": result["total"] or 0,
|
||||
"current": result["current"] or 0,
|
||||
"data": result["data"],
|
||||
}
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
@ -289,8 +291,8 @@ class WebsiteService:
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
|
||||
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
@ -343,7 +345,7 @@ class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
crawl_data: list[FirecrawlDocumentData] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
stored_data = storage.load_once(file_key)
|
||||
@ -352,19 +354,22 @@ class WebsiteService:
|
||||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
if result["status"] != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
crawl_data = result.get("data")
|
||||
crawl_data = result["data"]
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
if item["source_url"] == url:
|
||||
return dict(item)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
|
||||
def _get_watercrawl_url_data(
|
||||
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
|
||||
return dict(result) if result is not None else None
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
@ -416,8 +421,8 @@ class WebsiteService:
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
params = {"onlyMainContent": request.only_main_content}
|
||||
return firecrawl_app.scrape_url(url=request.url, params=params)
|
||||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@classmethod
|
||||
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
|
||||
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url))
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType
|
||||
from dify_graph.variables.variables import StringVariable
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models.enums import CreatorUserRole
|
||||
@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Create an upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_offload_{uuid.uuid4()}.json",
|
||||
name="test_offload.json",
|
||||
size=len(content_bytes),
|
||||
@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Create upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_integration_{uuid.uuid4()}.txt",
|
||||
name="test_integration.txt",
|
||||
size=len(content_bytes),
|
||||
|
||||
@ -6,6 +6,7 @@ from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -13,6 +13,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory:
|
||||
"""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
||||
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
|
||||
@ -83,7 +84,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n
|
||||
"""Persist an upload file row referenced by document.data_source_info."""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
||||
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@ -140,7 +141,7 @@ class TestFileService:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"upload_files/test/{fake.uuid4()}.txt",
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType
|
||||
from models.enums import DataSourceType, MessageChainType
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
# MessageChain
|
||||
chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="system",
|
||||
type=MessageChainType.SYSTEM,
|
||||
input=json.dumps({"test": "input"}),
|
||||
output=json.dumps({"test": "output"}),
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -209,7 +210,7 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
||||
@ -19,6 +19,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
@ -203,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
||||
@ -18,6 +18,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -254,7 +255,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
@ -925,7 +926,7 @@ class TestCleanDatasetTask:
|
||||
special_filename = f"test_file_{special_content}.txt"
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{special_filename}",
|
||||
name=special_filename,
|
||||
size=1024,
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
|
||||
for i in range(count):
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test/file-{uuid.uuid4()}-{i}.json",
|
||||
name=f"file-{i}.json",
|
||||
size=1024 + i,
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
|
||||
|
||||
class TestOpenDALFsDefaultRoot:
|
||||
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
|
||||
|
||||
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
|
||||
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
|
||||
# Change to tmp_path so the default "storage" dir is created there
|
||||
monkeypatch.chdir(tmp_path)
|
||||
# Ensure no OPENDAL_FS_ROOT env var is set
|
||||
monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
# The default directory should have been created
|
||||
assert (tmp_path / "storage").is_dir()
|
||||
# The storage should be functional
|
||||
storage.save("test_default_root.txt", b"hello")
|
||||
assert storage.exists("test_default_root.txt")
|
||||
assert storage.load_once("test_default_root.txt") == b"hello"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_default_root.txt")
|
||||
|
||||
def test_fs_with_explicit_root(self, tmp_path):
|
||||
"""When root is explicitly provided, it should be used."""
|
||||
custom_root = str(tmp_path / "custom_storage")
|
||||
storage = OpenDALStorage(scheme="fs", root=custom_root)
|
||||
|
||||
assert Path(custom_root).is_dir()
|
||||
storage.save("test_explicit_root.txt", b"world")
|
||||
assert storage.exists("test_explicit_root.txt")
|
||||
assert storage.load_once("test_explicit_root.txt") == b"world"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_explicit_root.txt")
|
||||
|
||||
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
|
||||
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
|
||||
env_root = str(tmp_path / "env_storage")
|
||||
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)
|
||||
# Ensure .env file doesn't interfere
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
assert Path(env_root).is_dir()
|
||||
storage.save("test_env_root.txt", b"env_data")
|
||||
assert storage.exists("test_env_root.txt")
|
||||
assert storage.load_once("test_env_root.txt") == b"env_data"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_env_root.txt")
|
||||
@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
@ -1121,7 +1122,7 @@ class TestDatasetIndexingEstimateApi:
|
||||
def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="key",
|
||||
name="name.txt",
|
||||
size=1,
|
||||
|
||||
@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.banner as banner_module
|
||||
from models.enums import BannerStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -20,7 +21,7 @@ class TestBannerApi:
|
||||
banner.content = {"text": "hello"}
|
||||
banner.link = "https://example.com"
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
@ -54,7 +55,7 @@ class TestBannerApi:
|
||||
banner.content = {"text": "fallback"}
|
||||
banner.link = None
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
|
||||
@ -39,14 +39,21 @@ class TestHitTestingPayload:
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all optional fields."""
|
||||
retrieval_model_data = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
"top_k": 5,
|
||||
}
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model={"top_k": 5},
|
||||
retrieval_model=retrieval_model_data,
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
assert payload.query == "test query"
|
||||
assert payload.retrieval_model == {"top_k": 5}
|
||||
assert payload.retrieval_model is not None
|
||||
assert payload.retrieval_model.top_k == 5
|
||||
assert payload.external_retrieval_model == {"provider": "openai"}
|
||||
assert payload.attachment_ids == ["att_1", "att_2"]
|
||||
|
||||
@ -134,7 +141,13 @@ class TestHitTestingApiPost:
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": True,
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.8,
|
||||
}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
@ -152,7 +165,11 @@ class TestHitTestingApiPost:
|
||||
|
||||
assert response["query"] == "complex query"
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
|
||||
# retrieval_model is serialized via model_dump, verify key fields
|
||||
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
|
||||
assert passed_retrieval_model is not None
|
||||
assert passed_retrieval_model["search_method"] == "semantic_search"
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
|
||||
@ -23,6 +23,7 @@ def mock_jsonify():
|
||||
|
||||
class DummyWebhookTrigger:
|
||||
webhook_id = "wh-1"
|
||||
webhook_url = "http://localhost:5001/triggers/webhook/wh-1"
|
||||
tenant_id = "tenant-1"
|
||||
app_id = "app-1"
|
||||
node_id = "node-1"
|
||||
@ -104,7 +105,32 @@ class TestHandleWebhookDebug:
|
||||
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
|
||||
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch")
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0)
|
||||
def test_debug_requires_active_listener(
|
||||
self,
|
||||
mock_dispatch,
|
||||
mock_build_inputs,
|
||||
mock_extract,
|
||||
mock_get,
|
||||
):
|
||||
mock_get.return_value = (DummyWebhookTrigger(), None, "node_config")
|
||||
mock_extract.return_value = {"method": "POST"}
|
||||
|
||||
response, status = module.handle_webhook_debug("wh-1")
|
||||
|
||||
assert status == 409
|
||||
assert response["error"] == "No active debug listener"
|
||||
assert response["message"] == (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
)
|
||||
assert response["execution_url"] == DummyWebhookTrigger.webhook_url
|
||||
mock_dispatch.assert_called_once()
|
||||
|
||||
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
|
||||
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1)
|
||||
@patch.object(module.WebhookService, "generate_webhook_response")
|
||||
def test_debug_success(
|
||||
self,
|
||||
|
||||
@ -166,6 +166,7 @@ class TestDatasourceFileManager:
|
||||
# Setup
|
||||
mock_guess_ext.return_value = None # Cannot guess
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
|
||||
@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||
ProviderCredentialSchema,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
@ -409,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -
|
||||
|
||||
configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)
|
||||
|
||||
assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value
|
||||
assert existing_record.preferred_provider_type == ProviderType.SYSTEM
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||
id="lb-base",
|
||||
name="LB Base",
|
||||
credentials={},
|
||||
credential_source_type="provider",
|
||||
credential_source_type=CredentialSourceType.PROVIDER,
|
||||
)
|
||||
],
|
||||
),
|
||||
@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||
id="lb-custom",
|
||||
name="LB Custom",
|
||||
credentials={},
|
||||
credential_source_type="custom_model",
|
||||
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
],
|
||||
),
|
||||
@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None:
|
||||
configuration._update_load_balancing_configs_with_credential(
|
||||
credential_id="cred-1",
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
|
||||
@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non
|
||||
configuration._update_load_balancing_configs_with_credential(
|
||||
credential_id="cred-1",
|
||||
credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"),
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
|
||||
|
||||
@ -104,10 +104,11 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_map_known_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error"))
|
||||
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))
|
||||
|
||||
assert app.map("https://example.com") == {}
|
||||
with pytest.raises(Exception, match="map error"):
|
||||
app.map("https://example.com")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_map_unknown_error_raises(self, mocker: MockerFixture):
|
||||
@ -177,10 +178,11 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error"))
|
||||
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))
|
||||
|
||||
assert app.check_crawl_status("job-1") == {}
|
||||
with pytest.raises(Exception, match="crawl error"):
|
||||
app.check_crawl_status("job-1")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
|
||||
@ -272,9 +274,10 @@ class TestFirecrawlApp:
|
||||
|
||||
def test_search_known_http_error(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error")
|
||||
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error"))
|
||||
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
|
||||
assert app.search("python") == {}
|
||||
with pytest.raises(Exception, match="search error"):
|
||||
app.search("python")
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
def test_search_unknown_http_error(self, mocker: MockerFixture):
|
||||
|
||||
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
@ -0,0 +1,106 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
@ -0,0 +1,63 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||
init_params = build_test_graph_init_params(
|
||||
graph_config=graph_config,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", files=[]),
|
||||
user_inputs={"payload": "value"},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def _build_node_config() -> NodeConfigDict:
|
||||
return NodeConfigDictAdapter.validate_python(
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {
|
||||
"type": TRIGGER_PLUGIN_NODE_TYPE,
|
||||
"title": "Trigger Event",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"subscription_id": "subscription-id",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
"event_parameters": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
id="node-1",
|
||||
config=_build_node_config(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
"plugin_unique_identifier": "plugin-unique-identifier",
|
||||
}
|
||||
19
api/tests/unit_tests/dify_graph/node_events/test_base.py
Normal file
19
api/tests/unit_tests/dify_graph/node_events/test_base.py
Normal file
@ -0,0 +1,19 @@
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events.base import NodeRunResult
|
||||
|
||||
|
||||
def test_node_run_result_accepts_trigger_info_metadata() -> None:
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
|
||||
"provider_id": "provider-id",
|
||||
"event_name": "event-name",
|
||||
}
|
||||
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def test_creator_user_role_missing_maps_hyphen_to_enum():
|
||||
# given an alias with hyphen
|
||||
value = "end-user"
|
||||
|
||||
# when converting to enum (invokes StrEnum._missing_ override)
|
||||
role = CreatorUserRole(value)
|
||||
|
||||
# then it should map to END_USER
|
||||
assert role is CreatorUserRole.END_USER
|
||||
|
||||
|
||||
def test_creator_user_role_missing_raises_for_unknown():
|
||||
with pytest.raises(ValueError):
|
||||
CreatorUserRole("unknown")
|
||||
@ -19,6 +19,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.enums import CredentialSourceType, PaymentStatus
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -158,7 +159,7 @@ class TestProviderModel:
|
||||
# Assert
|
||||
assert provider.tenant_id == tenant_id
|
||||
assert provider.provider_name == provider_name
|
||||
assert provider.provider_type == "custom"
|
||||
assert provider.provider_type == ProviderType.CUSTOM
|
||||
assert provider.is_valid is False
|
||||
assert provider.quota_used == 0
|
||||
|
||||
@ -172,10 +173,10 @@ class TestProviderModel:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name="anthropic",
|
||||
provider_type="system",
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
quota_type="paid",
|
||||
quota_type=ProviderQuotaType.PAID,
|
||||
quota_limit=10000,
|
||||
quota_used=500,
|
||||
)
|
||||
@ -183,10 +184,10 @@ class TestProviderModel:
|
||||
# Assert
|
||||
assert provider.tenant_id == tenant_id
|
||||
assert provider.provider_name == "anthropic"
|
||||
assert provider.provider_type == "system"
|
||||
assert provider.provider_type == ProviderType.SYSTEM
|
||||
assert provider.is_valid is True
|
||||
assert provider.credential_id == credential_id
|
||||
assert provider.quota_type == "paid"
|
||||
assert provider.quota_type == ProviderQuotaType.PAID
|
||||
assert provider.quota_limit == 10000
|
||||
assert provider.quota_used == 500
|
||||
|
||||
@ -199,7 +200,7 @@ class TestProviderModel:
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider.provider_type == "custom"
|
||||
assert provider.provider_type == ProviderType.CUSTOM
|
||||
assert provider.is_valid is False
|
||||
assert provider.quota_type == ""
|
||||
assert provider.quota_limit is None
|
||||
@ -213,7 +214,7 @@ class TestProviderModel:
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
provider_type="custom",
|
||||
provider_type=ProviderType.CUSTOM,
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -253,7 +254,7 @@ class TestProviderModel:
|
||||
provider = Provider(
|
||||
tenant_id=str(uuid4()),
|
||||
provider_name="openai",
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
@ -266,13 +267,13 @@ class TestProviderModel:
|
||||
provider = Provider(
|
||||
tenant_id=str(uuid4()),
|
||||
provider_name="openai",
|
||||
quota_type="trial",
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=1000,
|
||||
quota_used=250,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider.quota_type == "trial"
|
||||
assert provider.quota_type == ProviderQuotaType.TRIAL
|
||||
assert provider.quota_limit == 1000
|
||||
assert provider.quota_used == 250
|
||||
remaining = provider.quota_limit - provider.quota_used
|
||||
@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider:
|
||||
preferred = TenantPreferredModelProvider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
preferred_provider_type="custom",
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert preferred.tenant_id == tenant_id
|
||||
assert preferred.provider_name == "openai"
|
||||
assert preferred.preferred_provider_type == "custom"
|
||||
assert preferred.preferred_provider_type == ProviderType.CUSTOM
|
||||
|
||||
def test_tenant_preferred_provider_system_type(self):
|
||||
"""Test tenant preferred provider with system type."""
|
||||
@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider:
|
||||
preferred = TenantPreferredModelProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
provider_name="anthropic",
|
||||
preferred_provider_type="system",
|
||||
preferred_provider_type=ProviderType.SYSTEM,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert preferred.preferred_provider_type == "system"
|
||||
assert preferred.preferred_provider_type == ProviderType.SYSTEM
|
||||
|
||||
|
||||
class TestProviderOrder:
|
||||
@ -470,7 +471,7 @@ class TestProviderOrder:
|
||||
quantity=1,
|
||||
currency=None,
|
||||
total_amount=None,
|
||||
payment_status="wait_pay",
|
||||
payment_status=PaymentStatus.WAIT_PAY,
|
||||
paid_at=None,
|
||||
pay_failed_at=None,
|
||||
refunded_at=None,
|
||||
@ -481,7 +482,7 @@ class TestProviderOrder:
|
||||
assert order.provider_name == "openai"
|
||||
assert order.account_id == account_id
|
||||
assert order.payment_product_id == "prod_123"
|
||||
assert order.payment_status == "wait_pay"
|
||||
assert order.payment_status == PaymentStatus.WAIT_PAY
|
||||
assert order.quantity == 1
|
||||
|
||||
def test_provider_order_with_payment_details(self):
|
||||
@ -502,7 +503,7 @@ class TestProviderOrder:
|
||||
quantity=5,
|
||||
currency="USD",
|
||||
total_amount=9999,
|
||||
payment_status="paid",
|
||||
payment_status=PaymentStatus.PAID,
|
||||
paid_at=paid_time,
|
||||
pay_failed_at=None,
|
||||
refunded_at=None,
|
||||
@ -514,7 +515,7 @@ class TestProviderOrder:
|
||||
assert order.quantity == 5
|
||||
assert order.currency == "USD"
|
||||
assert order.total_amount == 9999
|
||||
assert order.payment_status == "paid"
|
||||
assert order.payment_status == PaymentStatus.PAID
|
||||
assert order.paid_at == paid_time
|
||||
|
||||
def test_provider_order_payment_statuses(self):
|
||||
@ -536,23 +537,23 @@ class TestProviderOrder:
|
||||
}
|
||||
|
||||
# Act & Assert - Wait pay status
|
||||
wait_order = ProviderOrder(**base_params, payment_status="wait_pay")
|
||||
assert wait_order.payment_status == "wait_pay"
|
||||
wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY)
|
||||
assert wait_order.payment_status == PaymentStatus.WAIT_PAY
|
||||
|
||||
# Act & Assert - Paid status
|
||||
paid_order = ProviderOrder(**base_params, payment_status="paid")
|
||||
assert paid_order.payment_status == "paid"
|
||||
paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID)
|
||||
assert paid_order.payment_status == PaymentStatus.PAID
|
||||
|
||||
# Act & Assert - Failed status
|
||||
failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)}
|
||||
failed_order = ProviderOrder(**failed_params, payment_status="failed")
|
||||
assert failed_order.payment_status == "failed"
|
||||
failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED)
|
||||
assert failed_order.payment_status == PaymentStatus.FAILED
|
||||
assert failed_order.pay_failed_at is not None
|
||||
|
||||
# Act & Assert - Refunded status
|
||||
refunded_params = {**base_params, "refunded_at": datetime.now(UTC)}
|
||||
refunded_order = ProviderOrder(**refunded_params, payment_status="refunded")
|
||||
assert refunded_order.payment_status == "refunded"
|
||||
refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED)
|
||||
assert refunded_order.payment_status == PaymentStatus.REFUNDED
|
||||
assert refunded_order.refunded_at is not None
|
||||
|
||||
|
||||
@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig:
|
||||
name="Secondary API Key",
|
||||
encrypted_config='{"api_key": "encrypted_value"}',
|
||||
credential_id=credential_id,
|
||||
credential_source_type="custom",
|
||||
credential_source_type=CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert config.encrypted_config == '{"api_key": "encrypted_value"}'
|
||||
assert config.credential_id == credential_id
|
||||
assert config.credential_source_type == "custom"
|
||||
assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
def test_load_balancing_config_disabled(self):
|
||||
"""Test disabled load balancing config."""
|
||||
|
||||
@ -443,7 +443,7 @@ def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monk
|
||||
|
||||
def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
firecrawl_instance = MagicMock()
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed"}
|
||||
firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 1, "current": 1, "data": []}
|
||||
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
|
||||
|
||||
redis_mock = MagicMock()
|
||||
|
||||
65
api/uv.lock
generated
65
api/uv.lock
generated
@ -1533,7 +1533,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.13.1"
|
||||
version = "1.13.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
@ -5405,11 +5405,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -7248,30 +7248,43 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "ujson"
|
||||
version = "5.9.0"
|
||||
version = "5.12.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6e/54/6f2bdac7117e89a47de4511c9f01732a283457ab1bf856e1e51aa861619e/ujson-5.9.0.tar.gz", hash = "sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532", size = 7154214, upload-time = "2023-12-10T22:50:34.812Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cb/3e/c35530c5ffc25b71c59ae0cd7b8f99df37313daa162ce1e2f7925f7c2877/ujson-5.12.0.tar.gz", hash = "sha256:14b2e1eb528d77bc0f4c5bd1a7ebc05e02b5b41beefb7e8567c9675b8b13bcf4", size = 7158451, upload-time = "2026-03-11T22:19:30.397Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/ca/ae3a6ca5b4f82ce654d6ac3dde5e59520537e20939592061ba506f4e569a/ujson-5.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b23bbb46334ce51ddb5dded60c662fbf7bb74a37b8f87221c5b0fec1ec6454b", size = 57753, upload-time = "2023-12-10T22:49:03.939Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/5f/c27fa9a1562c96d978c39852b48063c3ca480758f3088dcfc0f3b09f8e93/ujson-5.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6974b3a7c17bbf829e6c3bfdc5823c67922e44ff169851a755eab79a3dd31ec0", size = 54092, upload-time = "2023-12-10T22:49:05.194Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/f3/1431713de9e5992e5e33ba459b4de28f83904233958855d27da820a101f9/ujson-5.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5964ea916edfe24af1f4cc68488448fbb1ec27a3ddcddc2b236da575c12c8ae", size = 51675, upload-time = "2023-12-10T22:49:06.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/93/de6fff3ae06351f3b1c372f675fe69bc180f93d237c9e496c05802173dd6/ujson-5.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba7cac47dd65ff88571eceeff48bf30ed5eb9c67b34b88cb22869b7aa19600d", size = 53246, upload-time = "2023-12-10T22:49:07.691Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/73/db509fe1d7da62a15c0769c398cec66bdfc61a8bdffaf7dfa9d973e3d65c/ujson-5.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bbd91a151a8f3358c29355a491e915eb203f607267a25e6ab10531b3b157c5e", size = 58182, upload-time = "2023-12-10T22:49:08.89Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/a8/6be607fa3e1fa3e1c9b53f5de5acad33b073b6cc9145803e00bcafa729a8/ujson-5.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:829a69d451a49c0de14a9fecb2a2d544a9b2c884c2b542adb243b683a6f15908", size = 584493, upload-time = "2023-12-10T22:49:11.043Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/c7/33822c2f1a8175e841e2bc378ffb2c1109ce9280f14cedb1b2fa0caf3145/ujson-5.9.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a807ae73c46ad5db161a7e883eec0fbe1bebc6a54890152ccc63072c4884823b", size = 656038, upload-time = "2023-12-10T22:49:12.651Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/b8/5309fbb299d5fcac12bbf3db20896db5178392904abe6b992da233dc69d6/ujson-5.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8fc2aa18b13d97b3c8ccecdf1a3c405f411a6e96adeee94233058c44ff92617d", size = 597643, upload-time = "2023-12-10T22:49:14.883Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/64/7b63043b95dd78feed401b9973958af62645a6d19b72b6e83d1ea5af07e0/ujson-5.9.0-cp311-cp311-win32.whl", hash = "sha256:70e06849dfeb2548be48fdd3ceb53300640bc8100c379d6e19d78045e9c26120", size = 38342, upload-time = "2023-12-10T22:49:16.854Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/13/a3cd1fc3a1126d30b558b6235c05e2d26eeaacba4979ee2fd2b5745c136d/ujson-5.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7309d063cd392811acc49b5016728a5e1b46ab9907d321ebbe1c2156bc3c0b99", size = 41923, upload-time = "2023-12-10T22:49:17.983Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/7e/c37fca6cd924931fa62d615cdbf5921f34481085705271696eff38b38867/ujson-5.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:20509a8c9f775b3a511e308bbe0b72897ba6b800767a7c90c5cca59d20d7c42c", size = 57834, upload-time = "2023-12-10T22:49:19.799Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/44/2753e902ee19bf6ccaf0bda02f1f0037f92a9769a5d31319905e3de645b4/ujson-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b28407cfe315bd1b34f1ebe65d3bd735d6b36d409b334100be8cdffae2177b2f", size = 54119, upload-time = "2023-12-10T22:49:21.039Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/06/2317433e394450bc44afe32b6c39d5a51014da4c6f6cfc2ae7bf7b4a2922/ujson-5.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d302bd17989b6bd90d49bade66943c78f9e3670407dbc53ebcf61271cadc399", size = 51658, upload-time = "2023-12-10T22:49:22.494Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/3a/2acf0da085d96953580b46941504aa3c91a1dd38701b9e9bfa43e2803467/ujson-5.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f21315f51e0db8ee245e33a649dd2d9dce0594522de6f278d62f15f998e050e", size = 53370, upload-time = "2023-12-10T22:49:24.045Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/32/737e6c4b1841720f88ae88ec91f582dc21174bd40742739e1fa16a0c9ffa/ujson-5.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5635b78b636a54a86fdbf6f027e461aa6c6b948363bdf8d4fbb56a42b7388320", size = 58278, upload-time = "2023-12-10T22:49:25.261Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/dc/3fda97f1ad070ccf2af597fb67dde358bc698ffecebe3bc77991d60e4fe5/ujson-5.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82b5a56609f1235d72835ee109163c7041b30920d70fe7dac9176c64df87c164", size = 584418, upload-time = "2023-12-10T22:49:27.573Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/57/e4083d774fcd8ff3089c0ff19c424abe33f23e72c6578a8172bf65131992/ujson-5.9.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5ca35f484622fd208f55041b042d9d94f3b2c9c5add4e9af5ee9946d2d30db01", size = 656126, upload-time = "2023-12-10T22:49:29.509Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/c3/8c6d5f6506ca9fcedd5a211e30a7d5ee053dc05caf23dae650e1f897effb/ujson-5.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:829b824953ebad76d46e4ae709e940bb229e8999e40881338b3cc94c771b876c", size = 597795, upload-time = "2023-12-10T22:49:31.029Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/5a/a231f0cd305a34cf2d16930304132db3a7a8c3997b367dd38fc8f8dfae36/ujson-5.9.0-cp312-cp312-win32.whl", hash = "sha256:25fa46e4ff0a2deecbcf7100af3a5d70090b461906f2299506485ff31d9ec437", size = 38495, upload-time = "2023-12-10T22:49:33.2Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/b7/18b841b44760ed298acdb150608dccdc045c41655e0bae4441f29bcab872/ujson-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:60718f1720a61560618eff3b56fd517d107518d3c0160ca7a5a66ac949c6cf1c", size = 42088, upload-time = "2023-12-10T22:49:34.921Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/22/fd22e2f6766bae934d3050517ca47d463016bd8688508d1ecc1baa18a7ad/ujson-5.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58a11cb49482f1a095a2bd9a1d81dd7c8fb5d2357f959ece85db4e46a825fd00", size = 56139, upload-time = "2026-03-11T22:18:04.591Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/fd/6839adff4fc0164cbcecafa2857ba08a6eaeedd7e098d6713cb899a91383/ujson-5.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9b3cf13facf6f77c283af0e1713e5e8c47a0fe295af81326cb3cb4380212e797", size = 53836, upload-time = "2026-03-11T22:18:05.662Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/b0/0c19faac62d68ceeffa83a08dc3d71b8462cf5064d0e7e0b15ba19898dad/ujson-5.12.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb94245a715b4d6e24689de12772b85329a1f9946cbf6187923a64ecdea39e65", size = 57851, upload-time = "2026-03-11T22:18:06.744Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/f6/e7fd283788de73b86e99e08256726bb385923249c21dcd306e59d532a1a1/ujson-5.12.0-cp311-cp311-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:0fe6b8b8968e11dd9b2348bd508f0f57cf49ab3512064b36bc4117328218718e", size = 59906, upload-time = "2026-03-11T22:18:07.791Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/3a/b100735a2b43ee6e8fe4c883768e362f53576f964d4ea841991060aeaf35/ujson-5.12.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89e302abd3749f6d6699691747969a5d85f7c73081d5ed7e2624c7bd9721a2ab", size = 57409, upload-time = "2026-03-11T22:18:08.79Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/fa/f97cc20c99ca304662191b883ae13ae02912ca7244710016ba0cb8a5be34/ujson-5.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0727363b05ab05ee737a28f6200dc4078bce6b0508e10bd8aab507995a15df61", size = 1037339, upload-time = "2026-03-11T22:18:10.424Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/7a/53ddeda0ffe1420db2f9999897b3cbb920fbcff1849d1f22b196d0f34785/ujson-5.12.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b62cb9a7501e1f5c9ffe190485501349c33e8862dde4377df774e40b8166871f", size = 1196625, upload-time = "2026-03-11T22:18:11.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/1a/4c64a6bef522e9baf195dd5be151bc815cd4896c50c6e2489599edcda85f/ujson-5.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a6ec5bf6bc361f2f0f9644907a36ce527715b488988a8df534120e5c34eeda94", size = 1089669, upload-time = "2026-03-11T22:18:13.343Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/11/8ccb109f5777ec0d9fb826695a9e2ac36ae94c1949fc8b1e4d23a5bd067a/ujson-5.12.0-cp311-cp311-win32.whl", hash = "sha256:006428d3813b87477d72d306c40c09f898a41b968e57b15a7d88454ecc42a3fb", size = 39648, upload-time = "2026-03-11T22:18:14.785Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/e3/87fc4c27b20d5125cff7ce52d17ea7698b22b74426da0df238e3efcb0cf2/ujson-5.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:40aa43a7a3a8d2f05e79900858053d697a88a605e3887be178b43acbcd781161", size = 43876, upload-time = "2026-03-11T22:18:15.768Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/21/324f0548a8c8c48e3e222eaed15fb6d48c796593002b206b4a28a89e445f/ujson-5.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:561f89cc82deeae82e37d4a4764184926fb432f740a9691563a391b13f7339a4", size = 38553, upload-time = "2026-03-11T22:18:17.251Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/f6/ac763d2108d28f3a40bb3ae7d2fafab52ca31b36c2908a4ad02cd3ceba2a/ujson-5.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09b4beff9cc91d445d5818632907b85fb06943b61cb346919ce202668bf6794a", size = 56326, upload-time = "2026-03-11T22:18:18.467Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/46/d0b3af64dcdc549f9996521c8be6d860ac843a18a190ffc8affeb7259687/ujson-5.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca0c7ce828bb76ab78b3991904b477c2fd0f711d7815c252d1ef28ff9450b052", size = 53910, upload-time = "2026-03-11T22:18:19.502Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/10/853c723bcabc3e9825a079019055fc99e71b85c6bae600607a2b9d31d18d/ujson-5.12.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d79c6635ccffcbfc1d5c045874ba36b594589be81d50d43472570bb8de9c57", size = 57754, upload-time = "2026-03-11T22:18:20.874Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/c6/6e024830d988f521f144ead641981c1f7a82c17ad1927c22de3242565f5c/ujson-5.12.0-cp312-cp312-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:7e07f6f644d2c44d53b7a320a084eef98063651912c1b9449b5f45fcbdc6ccd2", size = 59936, upload-time = "2026-03-11T22:18:21.924Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/c9/c5f236af5abe06b720b40b88819d00d10182d2247b1664e487b3ed9229cf/ujson-5.12.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:085b6ce182cdd6657481c7c4003a417e0655c4f6e58b76f26ee18f0ae21db827", size = 57463, upload-time = "2026-03-11T22:18:22.924Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/04/41342d9ef68e793a87d84e4531a150c2b682f3bcedfe59a7a5e3f73e9213/ujson-5.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16b4fe9c97dc605f5e1887a9e1224287291e35c56cbc379f8aa44b6b7bcfe2bb", size = 1037239, upload-time = "2026-03-11T22:18:24.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/81/dc2b7617d5812670d4ff4a42f6dd77926430ee52df0dedb2aec7990b2034/ujson-5.12.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0d2e8db5ade3736a163906154ca686203acc7d1d30736cbf577c730d13653d84", size = 1196713, upload-time = "2026-03-11T22:18:25.391Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/9c/80acff0504f92459ed69e80a176286e32ca0147ac6a8252cd0659aad3227/ujson-5.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:93bc91fdadcf046da37a214eaa714574e7e9b1913568e93bb09527b2ceb7f759", size = 1089742, upload-time = "2026-03-11T22:18:26.738Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/f0/123ffaac17e45ef2b915e3e3303f8f4ea78bb8d42afad828844e08622b1e/ujson-5.12.0-cp312-cp312-win32.whl", hash = "sha256:2a248750abce1c76fbd11b2e1d88b95401e72819295c3b851ec73399d6849b3d", size = 39773, upload-time = "2026-03-11T22:18:28.244Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/20/f3bd2b069c242c2b22a69e033bfe224d1d15d3649e6cd7cc7085bb1412ff/ujson-5.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:1b5c6ceb65fecd28a1d20d1eba9dbfa992612b86594e4b6d47bb580d2dd6bcb3", size = 44040, upload-time = "2026-03-11T22:18:29.236Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/a7/01b5a0bcded14cd2522b218f2edc3533b0fcbccdea01f3e14a2b699071aa/ujson-5.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:9a5fcbe7b949f2e95c47ea8a80b410fcdf2da61c98553b45a4ee875580418b68", size = 38526, upload-time = "2026-03-11T22:18:30.551Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/3c/5ee154d505d1aad2debc4ba38b1a60ae1949b26cdb5fa070e85e320d6b64/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_10_13_x86_64.whl", hash = "sha256:bf85a00ac3b56a1e7a19c5be7b02b5180a0895ac4d3c234d717a55e86960691c", size = 54494, upload-time = "2026-03-11T22:19:13.035Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/b3/9496ec399ec921e434a93b340bd5052999030b7ac364be4cbe5365ac6b20/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:64df53eef4ac857eb5816a56e2885ccf0d7dff6333c94065c93b39c51063e01d", size = 57999, upload-time = "2026-03-11T22:19:14.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/da/e9ae98133336e7c0d50b43626c3f2327937cecfa354d844e02ac17379ed1/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c0aed6a4439994c9666fb8a5b6c4eac94d4ef6ddc95f9b806a599ef83547e3b", size = 54518, upload-time = "2026-03-11T22:19:15.4Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/10/978d89dded6bb1558cd46ba78f4351198bd2346db8a8ee1a94119022ce40/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efae5df7a8cc8bdb1037b0f786b044ce281081441df5418c3a0f0e1f86fe7bb3", size = 55736, upload-time = "2026-03-11T22:19:16.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/25/1df8e6217c92e57a1266bf5be750b1dddc126ee96e53fe959d5693503bc6/ujson-5.12.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:8712b61eb1b74a4478cfd1c54f576056199e9f093659334aeb5c4a6b385338e5", size = 44615, upload-time = "2026-03-11T22:19:17.53Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/fa/f4a957dddb99bd68c8be91928c0b6fefa7aa8aafc92c93f5d1e8b32f6702/ujson-5.12.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:871c0e5102e47995b0e37e8df7819a894a6c3da0d097545cd1f9f1f7d7079927", size = 52145, upload-time = "2026-03-11T22:19:18.566Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/6e/50b5cf612de1ca06c7effdc5a5d7e815774dee85a5858f1882c425553b82/ujson-5.12.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:56ba3f7abbd6b0bb282a544dc38406d1a188d8bb9164f49fdb9c2fee62cb29da", size = 49577, upload-time = "2026-03-11T22:19:19.627Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/24/b6713fa9897774502cd4c2d6955bb4933349f7d84c3aa805531c382a4209/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c5a52987a990eb1bae55f9000994f1afdb0326c154fb089992f839ab3c30688", size = 50807, upload-time = "2026-03-11T22:19:20.778Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/b6/c0e0f7901180ef80d16f3a4bccb5dc8b01515a717336a62928963a07b80b/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:adf28d13a33f9d750fe7a78fb481cac298fa257d8863d8727b2ea4455ea41235", size = 56972, upload-time = "2026-03-11T22:19:21.84Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/02/a9/05d91b4295ea7239151eb08cf240e5a2ba969012fda50bc27bcb1ea9cd71/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51acc750ec7a2df786cdc868fb16fa04abd6269a01d58cf59bafc57978773d8e", size = 52045, upload-time = "2026-03-11T22:19:22.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/7a/92047d32bf6f2d9db64605fc32e8eb0e0dd68b671eaafc12a464f69c4af4/ujson-5.12.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:ab9056d94e5db513d9313b34394f3a3b83e6301a581c28ad67773434f3faccab", size = 44053, upload-time = "2026-03-11T22:19:23.918Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
16
codecov.yml
Normal file
16
codecov.yml
Normal file
@ -0,0 +1,16 @@
|
||||
coverage:
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
target: auto
|
||||
|
||||
flags:
|
||||
web:
|
||||
paths:
|
||||
- "web/"
|
||||
carryforward: true
|
||||
|
||||
api:
|
||||
paths:
|
||||
- "api/"
|
||||
carryforward: true
|
||||
@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.1
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@ -728,7 +728,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -770,7 +770,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -809,7 +809,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.1
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -839,7 +839,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.1
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@ -29,7 +29,7 @@ const mockOnPlanInfoChanged = vi.fn()
|
||||
const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined)
|
||||
let mockDeleteMutationPending = false
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockRouterPush,
|
||||
}),
|
||||
@ -57,7 +57,7 @@ vi.mock('@headlessui/react', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('next/dynamic', () => ({
|
||||
vi.mock('@/next/dynamic', () => ({
|
||||
default: (loader: () => Promise<{ default: React.ComponentType }>) => {
|
||||
let Component: React.ComponentType<Record<string, unknown>> | null = null
|
||||
loader().then((mod) => {
|
||||
|
||||
@ -38,7 +38,7 @@ let mockShowTagManagementModal = false
|
||||
const mockRouterPush = vi.fn()
|
||||
const mockRouterReplace = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockRouterPush,
|
||||
replace: mockRouterReplace,
|
||||
@ -46,7 +46,7 @@ vi.mock('next/navigation', () => ({
|
||||
useSearchParams: () => new URLSearchParams(),
|
||||
}))
|
||||
|
||||
vi.mock('next/dynamic', () => ({
|
||||
vi.mock('@/next/dynamic', () => ({
|
||||
default: (_loader: () => Promise<{ default: React.ComponentType }>) => {
|
||||
const LazyComponent = (props: Record<string, unknown>) => {
|
||||
return <div data-testid="dynamic-component" {...props} />
|
||||
|
||||
@ -35,7 +35,7 @@ const mockRouterPush = vi.fn()
|
||||
const mockRouterReplace = vi.fn()
|
||||
const mockOnPlanInfoChanged = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockRouterPush,
|
||||
replace: mockRouterReplace,
|
||||
@ -117,7 +117,7 @@ vi.mock('ahooks', async () => {
|
||||
})
|
||||
|
||||
// Mock dynamically loaded modals with test stubs
|
||||
vi.mock('next/dynamic', () => ({
|
||||
vi.mock('@/next/dynamic', () => ({
|
||||
default: (loader: () => Promise<{ default: React.ComponentType }>) => {
|
||||
let Component: React.ComponentType<Record<string, unknown>> | null = null
|
||||
loader().then((mod) => {
|
||||
|
||||
@ -64,7 +64,7 @@ vi.mock('@/service/use-education', () => ({
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
const mockRouterPush = vi.fn()
|
||||
vi.mock('next/navigation', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({ push: mockRouterPush }),
|
||||
usePathname: () => '/billing',
|
||||
useSearchParams: () => new URLSearchParams(),
|
||||
|
||||
@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type'
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ALL_PLANS } from '@/app/components/billing/config'
|
||||
import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item'
|
||||
@ -21,7 +22,6 @@ let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockFetchSubscriptionUrls = vi.fn()
|
||||
const mockInvoices = vi.fn()
|
||||
const mockOpenAsyncWindow = vi.fn()
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
// ─── Context mocks ───────────────────────────────────────────────────────────
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@ -49,12 +49,8 @@ vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => mockOpenAsyncWindow,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
vi.mock('next/navigation', () => ({
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
usePathname: () => '/billing',
|
||||
useSearchParams: () => new URLSearchParams(),
|
||||
@ -82,12 +78,15 @@ const renderCloudPlanItem = ({
|
||||
canPay = true,
|
||||
}: RenderCloudPlanItemOptions = {}) => {
|
||||
return render(
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>,
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
|
||||
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
|
||||
@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
// Should not proceed with payment
|
||||
expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user