Mirror of https://github.com/langgenius/dify.git (synced 2026-05-09 12:59:18 +08:00)

Commit f9b2ff59c8: Merge remote-tracking branch 'origin/main' into verify-email-reset-flow

8  .github/actions/setup-web/action.yml (vendored)
@@ -4,10 +4,10 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
with:
node-version-file: "./web/.nvmrc"
node-version-file: web/.nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']
cwd: ./web
2  .github/workflows/anti-slop.yml (vendored)

@@ -12,7 +12,7 @@ jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@v0
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false
38  .github/workflows/api-tests.yml (vendored)

@@ -2,6 +2,12 @@ name: Run Pytest
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
@@ -11,6 +17,8 @@ jobs:
test:
name: API Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@@ -24,10 +32,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -79,21 +88,12 @@ jobs:
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
files: ./coverage.xml
disable_search: true
flags: api
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
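The Coverage Summary step above pulls the overall percentage out of coverage.json with an inline `python -c` call. A minimal standalone sketch of the same lookup, assuming a coverage.py-style coverage.json in the working directory (the printing is illustrative, not the workflow's exact script):

```python
import json
from pathlib import Path


def total_coverage(path: str = "coverage.json") -> str:
    # Same lookup the workflow's one-liner performs:
    # totals.percent_covered_display holds the overall percentage as a string.
    data = json.loads(Path(path).read_text())
    return data["totals"]["percent_covered_display"]


if __name__ == "__main__":
    print(f"Total Coverage: {total_coverage()}%")
```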
2  .github/workflows/autofix.yml (vendored)

@@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
4  .github/workflows/db-migration-test.yml (vendored)

@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
6  .github/workflows/main-ci.yml (vendored)

@@ -56,16 +56,14 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
secrets: inherit
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
with:
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
secrets: inherit
style-check:
name: Style Check
2  .github/workflows/pyrefly-diff.yml (vendored)

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
2  .github/workflows/style.yml (vendored)

@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: false
python-version: "3.12"
2  .github/workflows/translate-i18n-claude.yml (vendored)

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
2  .github/workflows/vdb-tests.yml (vendored)

@@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
65  .github/workflows/web-tests.yml (vendored)

@@ -2,16 +2,9 @@ name: Web Tests
on:
workflow_call:
inputs:
base_sha:
secrets:
CODECOV_TOKEN:
required: false
type: string
diff_range_mode:
required: false
type: string
head_sha:
required: false
type: string
permissions:
contents: read
@@ -63,7 +56,7 @@ jobs:
needs: [test]
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@@ -87,52 +80,16 @@ jobs:
merge-multiple: true
- name: Merge reports
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
run: vp test --merge-reports --coverage --silent=passed-only
- name: Report app/components baseline coverage
run: node ./scripts/report-components-coverage-baseline.mjs
- name: Report app/components test touch
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/report-components-test-touch.mjs
- name: Check app/components pure diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs
- name: Check Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
echo "" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
directory: web/coverage
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
web-build:
name: Web Build
@@ -1,9 +1,11 @@
import json
import logging
from typing import Any
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from configs import dify_config
from core.helper import encrypter
@@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(ToolOAuthSystemClient).where(
ToolOAuthSystemClient.provider == provider_name,
ToolOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(TriggerOAuthSystemClient).where(
TriggerOAuthSystemClient.provider == provider_name,
TriggerOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params):
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(DatasourceOauthParamConfig).where(
DatasourceOauthParamConfig.provider == provider_name,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str):
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
notion_credentials = db.session.scalars(
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
).all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for notion_credential in notion_credentials:
@@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str):
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
firecrawl_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
).all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for firecrawl_credential in firecrawl_credentials:
@@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str):
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
jina_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
).all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for jina_credential in jina_credentials:
@@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str):
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
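The three rewrites above share one pattern: the legacy `Query.filter_by(...).delete()` call becomes a 2.0-style `delete()` statement run through `session.execute()`, with the affected-row count read from `CursorResult.rowcount`. A minimal self-contained sketch of that pattern against a hypothetical model (not one of the Dify models above):

```python
from typing import cast

from sqlalchemy import String, create_engine, delete
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class OAuthClient(Base):  # hypothetical stand-in for the mapped classes used above
    __tablename__ = "oauth_client"
    id: Mapped[int] = mapped_column(primary_key=True)
    provider: Mapped[str] = mapped_column(String(64))
    plugin_id: Mapped[str] = mapped_column(String(64))


def delete_clients(session: Session, provider_name: str, plugin_id: str) -> int:
    # 2.0-style bulk delete; CursorResult.rowcount reports how many rows matched.
    result = cast(
        CursorResult,
        session.execute(
            delete(OAuthClient).where(
                OAuthClient.provider == provider_name,
                OAuthClient.plugin_id == plugin_id,
            )
        ),
    )
    session.commit()
    return result.rowcount


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        print(delete_clients(session, "notion", "example-plugin"))
```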
@@ -1,7 +1,10 @@
import json
from typing import cast
import click
import sqlalchemy as sa
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from configs import dify_config
from extensions.ext_database import db
@@ -740,14 +743,17 @@ def migrate_oss(
else:
try:
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
updated = (
db.session.query(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
)
updated = cast(
CursorResult,
db.session.execute(
update(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.values(storage_type=dify_config.STORAGE_TYPE)
),
).rowcount
db.session.commit()
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
except Exception as e:
@@ -2,6 +2,7 @@ import logging
import click
import sqlalchemy as sa
from sqlalchemy import delete, select, update
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@@ -41,7 +42,7 @@ def reset_encrypt_key_pair():
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
tenants = session.scalars(select(Tenant)).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
@@ -49,8 +50,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
click.echo(
click.style(
@@ -93,7 +94,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if app is not None:
apps.append(app)
@@ -108,8 +109,8 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT}
db.session.execute(
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
)
db.session.commit()
@@ -177,7 +178,7 @@ where sites.id is null limit 1000"""
continue
try:
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if not app:
logger.info("App %s not found", app_id)
continue
@@ -41,14 +41,13 @@ def migrate_annotation_vector_database():
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
apps = session.scalars(
select(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
).all()
if not apps:
break
except SQLAlchemyError:
@@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
try:
click.echo(f"Creating app annotation index: {app.id}")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
)
if not app_annotation_setting:
@@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
dataset_collection_binding = session.scalar(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
)
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
@@ -205,11 +204,11 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
dataset_collection_binding = db.session.execute(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == dataset.collection_binding_id
)
).scalar_one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
@@ -334,7 +333,7 @@ def add_qdrant_index(field: str):
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
@@ -421,10 +420,10 @@ def old_metadata_migration():
if field.value == key:
break
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
dataset_metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
.limit(1)
)
if not dataset_metadata:
dataset_metadata = DatasetMetadata(
@@ -436,7 +435,7 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata)
db.session.flush()
dataset_metadata_binding = DatasetMetadataBinding(
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
metadata_id=dataset_metadata.id,
@@ -445,14 +444,14 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
dataset_metadata_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
.limit(1)
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(
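The migration hunks above repeatedly swap legacy Query access for 2.0-style select() constructs. A minimal sketch of the three fetch styles those rewrites use, against a hypothetical App table (not Dify's actual model):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class App(Base):  # hypothetical stand-in for the ORM models used above
    __tablename__ = "app"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(32), default="normal")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # .scalars(...).all() replaces Query(...).all(): a list of ORM objects.
    apps = session.scalars(select(App).where(App.status == "normal")).all()
    # .scalar(... .limit(1)) replaces Query(...).first(): one object or None.
    first_app = session.scalar(select(App).where(App.status == "normal").limit(1))
    # .execute(...).scalar_one_or_none() replaces Query(...).one_or_none().
    maybe_app = session.execute(select(App).where(App.id == 1)).scalar_one_or_none()
    print(apps, first_app, maybe_app)
```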
@@ -103,13 +103,13 @@ class AppMCPServerController(Resource):
raise NotFound()
description = payload.description
if description is None:
pass
elif not description:
if description is None or not description:
server.description = app_model.description or ""
else:
server.description = description
server.name = app_model.name
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
try:
@@ -298,6 +298,7 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
@@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
retrieval_model: RetrievalModel | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
@@ -4,6 +4,7 @@ from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.enums import BannerStatus
from models.model import ExporleBanner
@@ -16,7 +17,7 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
@@ -517,7 +517,7 @@ class WorkflowResponseConverter:
snapshot = self._pop_snapshot(event.node_execution_id)
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
finished_at = event.finished_at or naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
@@ -456,6 +456,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@@ -471,6 +472,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
@@ -487,6 +489,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
@@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event.finished_at,
)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.FAILED,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
@@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
@@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
*,
error: str | None = None,
update_outputs: bool = True,
finished_at: datetime | None = None,
) -> None:
finished_at = naive_utc_now()
actual_finished_at = finished_at or naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
domain_execution.finished_at = actual_finished_at
domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
if error:
domain_execution.error = error
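The persistence-layer change above stops stamping finished_at with the time the event happens to be handled and instead prefers the timestamp carried on the event, falling back to naive_utc_now() only when the event has none. A small sketch of that fallback, with a local naive-UTC helper standing in for Dify's libs.datetime_utils:

```python
from datetime import datetime, timezone


def naive_utc_now() -> datetime:
    # Stand-in for libs.datetime_utils.naive_utc_now: current UTC time without tzinfo.
    return datetime.now(timezone.utc).replace(tzinfo=None)


def close_execution(start_at: datetime, event_finished_at: datetime | None) -> tuple[datetime, float]:
    # Prefer the producer's timestamp; only fall back to "now" if the event lacks one.
    finished_at = event_finished_at or naive_utc_now()
    elapsed = max((finished_at - start_at).total_seconds(), 0.0)
    return finished_at, elapsed


if __name__ == "__main__":
    print(close_execution(naive_utc_now(), None))
```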
@@ -15,6 +15,7 @@ from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
@@ -81,7 +82,7 @@ class DatasourceFileManager:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=filepath,
name=present_filename,
size=len(file_binary),
@@ -1422,12 +1422,12 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = s.execute(stmt).scalars().first()
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
preferred_model_provider.preferred_provider_type = provider_type
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
preferred_provider_type=provider_type,
)
s.add(preferred_model_provider)
s.commit()
@@ -195,7 +195,7 @@ class ProviderManager:
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:
@@ -68,9 +68,12 @@ class SegmentRecord(TypedDict):
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
@@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
document embeddings used in retrieval-augmented generation workflows.
"""
import atexit
import datetime
import json
import logging
@@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
_weaviate_client_lock = threading.Lock()
def _shutdown_weaviate_client() -> None:
"""
Best-effort shutdown hook to close the module-level Weaviate client.
This is registered with atexit so that HTTP/gRPC resources are released
when the Python interpreter exits.
"""
global _weaviate_client
# Ensure thread-safety when accessing the shared client instance
with _weaviate_client_lock:
client = _weaviate_client
_weaviate_client = None
if client is not None:
try:
client.close()
except Exception:
# Best-effort cleanup; log at debug level and ignore errors.
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
# Register the shutdown hook once per process.
atexit.register(_shutdown_weaviate_client)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
@@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
self._client = self._init_client(config)
self._attributes = attributes
def __del__(self):
"""
Destructor to properly close the Weaviate client connection.
Prevents connection leaks and resource warnings.
"""
if hasattr(self, "_client") and self._client is not None:
try:
self._client.close()
except Exception as e:
# Ignore errors during cleanup as object is being destroyed
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.
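The Weaviate change above drops a per-instance __del__ in favor of a process-wide atexit hook guarded by a lock. A minimal sketch of that shutdown-hook pattern for any shared client object (FakeClient here is illustrative, not the weaviate API):

```python
import atexit
import threading


class FakeClient:
    def close(self) -> None:
        print("client closed")


_client: FakeClient | None = None
_client_lock = threading.Lock()


def get_client() -> FakeClient:
    global _client
    with _client_lock:
        if _client is None:
            _client = FakeClient()
        return _client


def _shutdown_client() -> None:
    # Take the shared reference under the lock, then close it best-effort.
    global _client
    with _client_lock:
        client, _client = _client, None
    if client is not None:
        try:
            client.close()
        except Exception:
            pass  # never raise during interpreter shutdown


atexit.register(_shutdown_client)

if __name__ == "__main__":
    get_client()  # the hook closes it when the process exits
```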
@@ -1,12 +1,38 @@
import json
import time
from typing import Any, cast
from typing import Any, NotRequired, cast
import httpx
from typing_extensions import TypedDict
from extensions.ext_storage import storage
class FirecrawlDocumentData(TypedDict):
title: str | None
description: str | None
source_url: str | None
markdown: str | None
class CrawlStatusResponse(TypedDict):
status: str
total: int | None
current: int | None
data: list[FirecrawlDocumentData]
class MapResponse(TypedDict):
success: bool
links: list[str]
class SearchResponse(TypedDict):
success: bool
data: list[dict[str, Any]]
warning: NotRequired[str]
class FirecrawlApp:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
@@ -14,7 +40,7 @@ class FirecrawlApp:
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
raise ValueError("No API key provided")
def scrape_url(self, url, params=None) -> dict[str, Any]:
def scrape_url(self, url, params=None) -> FirecrawlDocumentData:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
headers = self._prepare_headers()
json_data = {
@@ -32,9 +58,7 @@ class FirecrawlApp:
return self._extract_common_fields(data)
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "scrape URL")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
@@ -51,7 +75,7 @@ class FirecrawlApp:
self._handle_error(response, "start crawl job")
return "" # unreachable
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
headers = self._prepare_headers()
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
@@ -60,14 +84,12 @@ class FirecrawlApp:
json_data.update(params)
response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
return cast(MapResponse, response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "start map job")
return {}
else:
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
def check_crawl_status(self, job_id) -> dict[str, Any]:
def check_crawl_status(self, job_id) -> CrawlStatusResponse:
headers = self._prepare_headers()
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200:
@@ -77,7 +99,7 @@ class FirecrawlApp:
if total == 0:
raise Exception("Failed to check crawl status. Error: No page found")
data = crawl_status_response.get("data", [])
url_data_list = []
url_data_list: list[FirecrawlDocumentData] = []
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = self._extract_common_fields(item)
@@ -95,13 +117,15 @@ class FirecrawlApp:
return self._format_crawl_status_response(
crawl_status_response.get("status"), crawl_status_response, []
)
else:
self._handle_error(response, "check crawl status")
return {} # unreachable
self._handle_error(response, "check crawl status")
raise RuntimeError("unreachable: _handle_error always raises")
def _format_crawl_status_response(
self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
) -> dict[str, Any]:
self,
status: str,
crawl_status_response: dict[str, Any],
url_data_list: list[FirecrawlDocumentData],
) -> CrawlStatusResponse:
return {
"status": status,
"total": crawl_status_response.get("total"),
@@ -109,7 +133,7 @@ class FirecrawlApp:
"data": url_data_list,
}
def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData:
return {
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
@@ -117,7 +141,7 @@ class FirecrawlApp:
"markdown": item.get("markdown"),
}
def _prepare_headers(self) -> dict[str, Any]:
def _prepare_headers(self) -> dict[str, str]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
@@ -150,10 +174,10 @@ class FirecrawlApp:
error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
headers = self._prepare_headers()
json_data = {
json_data: dict[str, Any] = {
"query": query,
"limit": 5,
"lang": "en",
@@ -170,12 +194,10 @@ class FirecrawlApp:
json_data.update(params)
response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
response_data: SearchResponse = response.json()
if not response_data.get("success"):
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
return cast(dict[str, Any], response_data)
return response_data
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "perform search")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
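The Firecrawl hunks above replace loose dict[str, Any] return types with TypedDicts such as CrawlStatusResponse and SearchResponse, using NotRequired for keys that may be absent. A small sketch of that typing pattern on a hypothetical API payload (not Firecrawl's actual schema):

```python
from typing import NotRequired, TypedDict


class DocumentData(TypedDict):
    title: str | None
    markdown: str | None


class StatusResponse(TypedDict):
    status: str
    total: int | None
    data: list[DocumentData]
    warning: NotRequired[str]  # only present when the API attaches a warning


def summarize(resp: StatusResponse) -> str:
    # Static checkers now know the shape; .get() is still used for NotRequired keys.
    warning = resp.get("warning", "")
    return f"{resp['status']}: {len(resp['data'])} documents {warning}".strip()


if __name__ == "__main__":
    print(summarize({"status": "completed", "total": 1, "data": [{"title": "t", "markdown": "m"}]}))
```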
@@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=len(img_bytes),
@@ -1,10 +1,11 @@
import json
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from urllib.parse import urljoin
import httpx
from httpx import Response
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
@@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import (
)
class SpiderOptions(TypedDict):
max_depth: int
page_limit: int
allowed_domains: list[str]
exclude_paths: list[str]
include_paths: list[str]
class PageOptions(TypedDict):
exclude_tags: list[str]
include_tags: list[str]
wait_time: int
include_html: bool
only_main_content: bool
include_links: bool
timeout: int
accept_cookies_selector: str
locale: str
actions: list[Any]
class BaseAPIClient:
def __init__(self, api_key, base_url):
self.api_key = api_key
@@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient):
def create_crawl_request(
self,
url: Union[list, str] | None = None,
spider_options: dict | None = None,
page_options: dict | None = None,
plugin_options: dict | None = None,
spider_options: SpiderOptions | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,
):
data = {
# 'urls': url if isinstance(url, list) else [url],
@@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient):
def scrape_url(
self,
url: str,
page_options: dict | None = None,
plugin_options: dict | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,
sync: bool = True,
prefetched: bool = True,
):
@@ -2,16 +2,39 @@ from collections.abc import Generator
from datetime import datetime
from typing import Any
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
class WatercrawlDocumentData(TypedDict):
title: str | None
description: str | None
source_url: str | None
markdown: str | None
class CrawlJobResponse(TypedDict):
status: str
job_id: str | None
class WatercrawlCrawlStatusResponse(TypedDict):
status: str
job_id: str | None
total: int
current: int
data: list[WatercrawlDocumentData]
time_consuming: float
class WaterCrawlProvider:
def __init__(self, api_key, base_url: str | None = None):
self.client = WaterCrawlAPIClient(api_key, base_url)
def crawl_url(self, url, options: dict | Any | None = None):
def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse:
options = options or {}
spider_options = {
spider_options: SpiderOptions = {
"max_depth": 1,
"page_limit": 1,
"allowed_domains": [],
@@ -25,7 +48,7 @@ class WaterCrawlProvider:
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
wait_time = options.get("wait_time", 1000)
page_options = {
page_options: PageOptions = {
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
@@ -41,9 +64,9 @@ class WaterCrawlProvider:
return {"status": "active", "job_id": result.get("uuid")}
def get_crawl_status(self, crawl_request_id):
def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse:
response = self.client.get_crawl_request(crawl_request_id)
data = []
data: list[WatercrawlDocumentData] = []
if response["status"] in ["new", "running"]:
status = "active"
else:
@@ -67,7 +90,7 @@ class WaterCrawlProvider:
"time_consuming": time_consuming,
}
def get_crawl_url_data(self, job_id, url) -> dict | None:
def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None:
if not job_id:
return self.scrape_url(url)
@@ -82,11 +105,11 @@ class WaterCrawlProvider:
return None
def scrape_url(self, url: str):
def scrape_url(self, url: str) -> WatercrawlDocumentData:
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
return self._structure_data(response)
def _structure_data(self, result_object: dict):
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
if isinstance(result_object.get("result", {}), str):
raise ValueError("Invalid result object. Expected a dictionary.")
@@ -98,7 +121,9 @@ class WaterCrawlProvider:
"markdown": result_object.get("result", {}).get("markdown"),
}
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
def _get_results(
self, crawl_request_id: str, query_params: dict | None = None
) -> Generator[WatercrawlDocumentData, None, None]:
page = 0
page_size = 100
@@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
@@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
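The extractor hunks above wrap the raw config string in StorageType(...) so an invalid STORAGE_TYPE fails loudly instead of persisting an arbitrary string. A minimal sketch of that coercion with a hypothetical StrEnum (the member names here are illustrative, not Dify's actual StorageType values):

```python
from enum import StrEnum


class StorageKind(StrEnum):  # hypothetical; Dify's StorageType lives in extensions.storage.storage_type
    LOCAL = "local"
    S3 = "s3"


def coerce_storage_kind(raw: str) -> StorageKind:
    # Enum construction validates the value; a typo raises ValueError at startup
    # rather than writing an unknown storage_type into the database.
    return StorageKind(raw)


if __name__ == "__main__":
    print(coerce_storage_kind("local"))
    try:
        coerce_storage_kind("dropbox")
    except ValueError as e:
        print(f"rejected: {e}")
```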
@@ -33,7 +33,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -666,7 +666,11 @@ class DatasetRetrieval:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
@@ -1058,7 +1062,11 @@ class DatasetRetrieval:
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_model: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
if dataset.indexing_technique == "economy":
# use keyword table query
@@ -1132,7 +1140,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -1141,7 +1149,11 @@ class DatasetRetrieval:
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
@@ -3,7 +3,6 @@ from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{
@@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, cast
from typing import Any
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
@@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
# Get trigger data passed when workflow was triggered
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
@@ -245,6 +245,9 @@ _END_STATE = frozenset(
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
Values in this enum are persisted as execution metadata and must stay in sync
with every node that writes `NodeRunResult.metadata`.
"""
TOTAL_TOKENS = "total_tokens"
@@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
@@ -159,6 +159,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
@@ -198,6 +199,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from dify_graph.context import IExecutionContext
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
@ -65,6 +68,7 @@ class Worker(threading.Thread):
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
self._current_node_started_at: datetime | None = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
@ -104,18 +108,15 @@ class Worker(threading.Thread):
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._current_node_started_at = None
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
self._event_queue.put(
|
||||
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
finally:
|
||||
self._current_node_started_at = None
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
@ -136,6 +137,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@ -149,6 +152,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@ -177,3 +182,24 @@ class Worker(threading.Thread):
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _build_fallback_failure_event(
|
||||
self, node: Node, error: Exception, *, started_at: datetime | None = None
|
||||
) -> NodeRunFailedEvent:
|
||||
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
|
||||
failure_time = naive_utc_now()
|
||||
error_message = str(error)
|
||||
return NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=error_message,
|
||||
start_at=started_at or failure_time,
|
||||
finished_at=failure_time,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
error_type=type(error).__name__,
|
||||
),
|
||||
)
|
||||
|
||||
@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
|
||||
@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
finished_at = naive_utc_now()
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
finished_at = naive_utc_now()
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
finished_at = naive_utc_now()
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
|
||||
@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
future_to_index: dict[
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
float,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, Variable],
|
||||
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
try:
|
||||
result = future.result()
|
||||
(
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
# The worker computes duration before we replay buffered events here,
|
||||
# so slow downstream consumers don't inflate per-iteration timing.
|
||||
iter_run_map[str(index)] = iteration_duration
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
index: int,
|
||||
item: object,
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
return (
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
|
||||
@ -256,9 +256,13 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
@ -24,13 +25,11 @@ def handle(sender, **kwargs):
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
document = db.session.scalar(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not document:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
||||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
|
||||
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).where(
|
||||
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
db.session.execute(
|
||||
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
|
||||
)
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
|
||||
@ -3,6 +3,7 @@ import json
|
||||
import flask_login
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login):
|
||||
if admin_api_key and admin_api_key == auth_token:
|
||||
workspace_id = request.headers.get("X-WORKSPACE-ID")
|
||||
if workspace_id:
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == workspace_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role == "owner")
|
||||
.one_or_none()
|
||||
)
|
||||
).one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).filter_by(id=ta.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == ta.account_id))
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if not end_user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login):
|
||||
decoded = PassportService().verify(auth_token)
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login):
|
||||
server_code = request.view_args.get("server_code") if request.view_args else None
|
||||
if not server_code:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||
if not app_mcp_server:
|
||||
raise NotFound("App MCP server not found.")
|
||||
end_user = (
|
||||
db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1)
|
||||
)
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
|
||||
@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage):
|
||||
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
|
||||
|
||||
if scheme == "fs":
|
||||
root = kwargs.get("root", "storage")
|
||||
root = kwargs.setdefault("root", "storage")
|
||||
Path(root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)
|
||||
|
||||
@ -424,13 +424,11 @@ def _build_from_datasource_file(
|
||||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not datasource_file_id:
|
||||
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
|
||||
datasource_file = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
datasource_file = db.session.scalar(
|
||||
select(UploadFile).where(
|
||||
UploadFile.id == datasource_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
|
||||
@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@ -23,13 +23,22 @@ from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
|
||||
from .enums import (
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
MessageChainType,
|
||||
MessageStatus,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
@ -925,8 +934,11 @@ class ExporleBanner(TypeBase):
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
||||
status: Mapped[BannerStatus] = mapped_column(
|
||||
EnumText(BannerStatus, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'enabled'::character varying"),
|
||||
default=BannerStatus.ENABLED,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
@ -2097,7 +2109,7 @@ class UploadFile(Base):
|
||||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
@ -2141,7 +2153,7 @@ class UploadFile(Base):
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
storage_type: str,
|
||||
storage_type: StorageType,
|
||||
key: str,
|
||||
name: str,
|
||||
size: int,
|
||||
@ -2206,7 +2218,7 @@ class MessageChain(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
|
||||
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -210,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@ -22,14 +22,14 @@ from sqlalchemy import (
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
|
||||
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
|
||||
elif (
|
||||
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
|
||||
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
|
||||
):
|
||||
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.1"
|
||||
version = "1.13.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[pytest]
|
||||
pythonpath = .
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||
|
||||
@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from core.helper.marketplace import fetch_global_plugin_manifest
|
||||
@ -28,17 +29,15 @@ def check_upgradable_plugin_task():
|
||||
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
|
||||
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(
|
||||
strategies = db.session.scalars(
|
||||
select(TenantPluginAutoUpgradeStrategy).where(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
TenantPluginAutoUpgradeStrategy.strategy_setting
|
||||
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
total_strategies = len(strategies)
|
||||
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))
|
||||
|
||||
@ -2,7 +2,7 @@ import datetime
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
@ -19,14 +19,12 @@ def clean_embedding_cache_task():
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
while True:
|
||||
try:
|
||||
embedding_ids = (
|
||||
db.session.query(Embedding.id)
|
||||
embedding_ids = db.session.scalars(
|
||||
select(Embedding.id)
|
||||
.where(Embedding.created_at < thirty_days_ago)
|
||||
.order_by(Embedding.created_at.desc())
|
||||
.limit(100)
|
||||
.all()
|
||||
)
|
||||
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
|
||||
).all()
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
if embedding_ids:
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from typing import TypedDict
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
@ -51,7 +51,7 @@ def clean_unused_datasets_task():
|
||||
try:
|
||||
# Subquery for counting new documents
|
||||
document_subquery_new = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
@ -64,7 +64,7 @@ def clean_unused_datasets_task():
|
||||
|
||||
# Subquery for counting old documents
|
||||
document_subquery_old = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
select(Document.dataset_id, func.count(Document.id).label("document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
@ -142,8 +142,8 @@ def clean_unused_datasets_task():
|
||||
index_processor.clean(dataset, None)
|
||||
|
||||
# Update document
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
|
||||
{Document.enabled: False}
|
||||
db.session.execute(
|
||||
update(Document).where(Document.dataset_id == dataset.id).values(enabled=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import func, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
@ -20,7 +21,7 @@ def create_tidb_serverless_task():
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
idle_tidb_serverless_number = (
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
|
||||
db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0
|
||||
)
|
||||
if idle_tidb_serverless_number >= tidb_serverless_number:
|
||||
break
|
||||
|
||||
@ -49,16 +49,18 @@ def mail_clean_document_notify_task():
|
||||
if plan != CloudPlan.SANDBOX:
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
|
||||
if not account:
|
||||
continue
|
||||
|
||||
@ -71,7 +73,7 @@ def mail_clean_document_notify_task():
|
||||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
|
||||
@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account
|
||||
@ -93,7 +94,7 @@ class FileService:
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_tenant_id or "",
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=filename,
|
||||
size=file_size,
|
||||
@ -152,7 +153,7 @@ class FileService:
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=file_key,
|
||||
name=text_name,
|
||||
size=len(text),
|
||||
|
||||
@ -9,7 +9,7 @@ import httpx
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData
|
||||
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -216,8 +216,10 @@ class WebsiteService:
|
||||
"max_depth": request.options.max_depth,
|
||||
"use_sitemap": request.options.use_sitemap,
|
||||
}
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
return dict(
|
||||
WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
|
||||
url=request.url, options=options
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -270,13 +272,13 @@ class WebsiteService:
|
||||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
"status": result.get("status", "active"),
|
||||
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data: dict[str, Any] = {
|
||||
"status": result["status"],
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
"total": result["total"] or 0,
|
||||
"current": result["current"] or 0,
|
||||
"data": result["data"],
|
||||
}
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
@ -289,8 +291,8 @@ class WebsiteService:
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
|
||||
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
|
||||
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
@ -343,7 +345,7 @@ class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
crawl_data: list[FirecrawlDocumentData] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
stored_data = storage.load_once(file_key)
|
||||
@ -352,19 +354,22 @@ class WebsiteService:
|
||||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
if result["status"] != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
crawl_data = result.get("data")
|
||||
crawl_data = result["data"]
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
if item["source_url"] == url:
|
||||
return dict(item)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
|
||||
def _get_watercrawl_url_data(
|
||||
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
|
||||
return dict(result) if result is not None else None
|
||||
|
||||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
@ -416,8 +421,8 @@ class WebsiteService:
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
params = {"onlyMainContent": request.only_main_content}
|
||||
return firecrawl_app.scrape_url(url=request.url, params=params)
|
||||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@classmethod
|
||||
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
|
||||
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url))
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType
|
||||
from dify_graph.variables.variables import StringVariable
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models.enums import CreatorUserRole
|
||||
@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Create an upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_offload_{uuid.uuid4()}.json",
|
||||
name="test_offload.json",
|
||||
size=len(content_bytes),
|
||||
@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Create upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_integration_{uuid.uuid4()}.txt",
|
||||
name="test_integration.txt",
|
||||
size=len(content_bytes),
|
||||
|
||||
@ -6,6 +6,7 @@ from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -13,6 +13,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory:
|
||||
"""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
||||
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
|
||||
@ -83,7 +84,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n
|
||||
"""Persist an upload file row referenced by document.data_source_info."""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
|
||||
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@ -140,7 +141,7 @@ class TestFileService:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"upload_files/test/{fake.uuid4()}.txt",
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
|
||||
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType
|
||||
from models.enums import DataSourceType, MessageChainType
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
# MessageChain
|
||||
chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="system",
|
||||
type=MessageChainType.SYSTEM,
|
||||
input=json.dumps({"test": "input"}),
|
||||
output=json.dumps({"test": "output"}),
|
||||
)
|
||||
|
||||
@ -48,41 +48,42 @@ class TestToolTransformService:
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
|
||||
provider_type="api",
|
||||
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
tools_str="[]",
|
||||
)
|
||||
elif provider_type == "builtin":
|
||||
provider = BuiltinToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon="🔧",
|
||||
icon_dark="🔧",
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
provider="test_provider",
|
||||
credential_type="api_key",
|
||||
credentials={"api_key": "test_key"},
|
||||
encrypted_credentials='{"api_key": "test_key"}',
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
provider = WorkflowToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
workflow_id="test_workflow_id",
|
||||
app_id="test_workflow_id",
|
||||
label="Test Workflow",
|
||||
version="1.0.0",
|
||||
parameter_configuration="[]",
|
||||
)
|
||||
elif provider_type == "mcp":
|
||||
provider = MCPToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
server_url="https://mcp.example.com",
|
||||
server_url_hash="test_server_url_hash",
|
||||
server_identifier="test_server",
|
||||
tools='[{"name": "test_tool", "description": "Test tool"}]',
|
||||
authed=True,
|
||||
|
||||
@ -13,6 +13,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -209,7 +210,7 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
||||
@ -19,6 +19,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
@ -203,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
|
||||
@ -18,6 +18,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -254,7 +255,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{fake.file_name()}",
|
||||
name=fake.file_name(),
|
||||
size=1024,
|
||||
@ -925,7 +926,7 @@ class TestCleanDatasetTask:
|
||||
special_filename = f"test_file_{special_content}.txt"
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test_files/{special_filename}",
|
||||
name=special_filename,
|
||||
size=1024,
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
|
||||
for i in range(count):
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"test/file-{uuid.uuid4()}-{i}.json",
|
||||
name=f"file-{i}.json",
|
||||
size=1024 + i,
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
|
||||
|
||||
class TestOpenDALFsDefaultRoot:
|
||||
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
|
||||
|
||||
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
|
||||
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
|
||||
# Change to tmp_path so the default "storage" dir is created there
|
||||
monkeypatch.chdir(tmp_path)
|
||||
# Ensure no OPENDAL_FS_ROOT env var is set
|
||||
monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
# The default directory should have been created
|
||||
assert (tmp_path / "storage").is_dir()
|
||||
# The storage should be functional
|
||||
storage.save("test_default_root.txt", b"hello")
|
||||
assert storage.exists("test_default_root.txt")
|
||||
assert storage.load_once("test_default_root.txt") == b"hello"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_default_root.txt")
|
||||
|
||||
def test_fs_with_explicit_root(self, tmp_path):
|
||||
"""When root is explicitly provided, it should be used."""
|
||||
custom_root = str(tmp_path / "custom_storage")
|
||||
storage = OpenDALStorage(scheme="fs", root=custom_root)
|
||||
|
||||
assert Path(custom_root).is_dir()
|
||||
storage.save("test_explicit_root.txt", b"world")
|
||||
assert storage.exists("test_explicit_root.txt")
|
||||
assert storage.load_once("test_explicit_root.txt") == b"world"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_explicit_root.txt")
|
||||
|
||||
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
|
||||
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
|
||||
env_root = str(tmp_path / "env_storage")
|
||||
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)
|
||||
# Ensure .env file doesn't interfere
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
storage = OpenDALStorage(scheme="fs")
|
||||
|
||||
assert Path(env_root).is_dir()
|
||||
storage.save("test_env_root.txt", b"env_data")
|
||||
assert storage.exists("test_env_root.txt")
|
||||
assert storage.load_once("test_env_root.txt") == b"env_data"
|
||||
|
||||
# Cleanup
|
||||
storage.delete("test_env_root.txt")
|
||||
@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
@ -1121,7 +1122,7 @@ class TestDatasetIndexingEstimateApi:
|
||||
def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="key",
|
||||
name="name.txt",
|
||||
size=1,
|
||||
|
||||
@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.banner as banner_module
|
||||
from models.enums import BannerStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -20,7 +21,7 @@ class TestBannerApi:
|
||||
banner.content = {"text": "hello"}
|
||||
banner.link = "https://example.com"
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
@ -54,7 +55,7 @@ class TestBannerApi:
|
||||
banner.content = {"text": "fallback"}
|
||||
banner.link = None
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.status = BannerStatus.ENABLED
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
|
||||
@ -39,14 +39,21 @@ class TestHitTestingPayload:
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all optional fields."""
|
||||
retrieval_model_data = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
"top_k": 5,
|
||||
}
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model={"top_k": 5},
|
||||
retrieval_model=retrieval_model_data,
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
assert payload.query == "test query"
|
||||
assert payload.retrieval_model == {"top_k": 5}
|
||||
assert payload.retrieval_model is not None
|
||||
assert payload.retrieval_model.top_k == 5
|
||||
assert payload.external_retrieval_model == {"provider": "openai"}
|
||||
assert payload.attachment_ids == ["att_1", "att_2"]
|
||||
|
||||
@ -134,7 +141,13 @@ class TestHitTestingApiPost:
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8}
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": True,
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.8,
|
||||
}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
@ -152,7 +165,11 @@ class TestHitTestingApiPost:
|
||||
|
||||
assert response["query"] == "complex query"
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model
|
||||
# retrieval_model is serialized via model_dump, verify key fields
|
||||
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
|
||||
assert passed_retrieval_model is not None
|
||||
assert passed_retrieval_model["search_method"] == "semantic_search"
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
|
||||
@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import Mock

@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
assert response.data.process_data == {}
assert response.data.process_data_truncated is False

def test_workflow_node_finish_response_prefers_event_finished_at(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Finished timestamps should come from the event, not delayed queue processing time."""
converter = self.create_workflow_response_converter()
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)

monkeypatch.setattr(
"core.app.apps.common.workflow_response_converter.naive_utc_now",
lambda: delayed_processing_time,
)
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)

event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=BuiltinNodeTypes.CODE,
node_execution_id="node-exec-1",
start_at=start_at,
finished_at=finished_at,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)

response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)

assert response is not None
assert response.data.elapsed_time == 2.0
assert response.data.finished_at == int(finished_at.timestamp())

def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()

@ -0,0 +1,60 @@
from datetime import UTC, datetime
from unittest.mock import Mock

import pytest

from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
_NodeRuntimeSnapshot,
)
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
from dify_graph.node_events import NodeRunResult


def _build_layer() -> WorkflowPersistenceLayer:
application_generate_entity = Mock()
application_generate_entity.inputs = {}

return WorkflowPersistenceLayer(
application_generate_entity=application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
version="1",
graph_data={},
),
workflow_execution_repository=Mock(),
workflow_node_execution_repository=Mock(),
)


def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
layer = _build_layer()
node_execution = Mock()
node_execution.id = "node-exec-1"
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
node_execution.update_from_mapping = Mock()

layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
node_id="node-id",
title="LLM",
predecessor_node_id=None,
iteration_id="iter-1",
loop_id=None,
created_at=node_execution.created_at,
)

event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)

layer._update_node_execution(
node_execution,
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event_finished_at,
)

assert node_execution.finished_at == event_finished_at
assert node_execution.elapsed_time == 2.0
@ -166,6 +166,7 @@ class TestDatasourceFileManager:
# Setup
mock_guess_ext.return_value = None # Cannot guess
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"

# Execute
upload_file = DatasourceFileManager.create_file_by_raw(

@ -410,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() -

configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session)

assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value
assert existing_record.preferred_provider_type == ProviderType.SYSTEM
session.commit.assert_called_once()


@ -104,10 +104,11 @@ class TestFirecrawlApp:

def test_map_known_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error"))
mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"}))

assert app.map("https://example.com") == {}
with pytest.raises(Exception, match="map error"):
app.map("https://example.com")
mock_handle.assert_called_once()

def test_map_unknown_error_raises(self, mocker: MockerFixture):
@ -177,10 +178,11 @@ class TestFirecrawlApp:

def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error"))
mocker.patch("httpx.get", return_value=_response(500, {"error": "server"}))

assert app.check_crawl_status("job-1") == {}
with pytest.raises(Exception, match="crawl error"):
app.check_crawl_status("job-1")
mock_handle.assert_called_once()

def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture):
@ -272,9 +274,10 @@ class TestFirecrawlApp:

def test_search_known_http_error(self, mocker: MockerFixture):
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
mock_handle = mocker.patch.object(app, "_handle_error")
mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error"))
mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"}))
assert app.search("python") == {}
with pytest.raises(Exception, match="search error"):
app.search("python")
mock_handle.assert_called_once()

def test_search_unknown_http_error(self, mocker: MockerFixture):

145
api/tests/unit_tests/core/workflow/graph_engine/test_worker.py
Normal file
@ -0,0 +1,145 @@
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from dify_graph.graph_engine.worker import Worker
|
||||
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
|
||||
|
||||
|
||||
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
|
||||
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=InMemoryReadyQueue(),
|
||||
event_queue=queue.Queue(),
|
||||
graph=MagicMock(),
|
||||
layers=[],
|
||||
)
|
||||
node = SimpleNamespace(
|
||||
execution_id="exec-1",
|
||||
id="node-1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
)
|
||||
|
||||
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
|
||||
|
||||
assert event.start_at == fixed_time
|
||||
assert event.finished_at == fixed_time
|
||||
assert event.error == "boom"
|
||||
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert event.node_run_result.error == "boom"
|
||||
assert event.node_run_result.error_type == "RuntimeError"
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
|
||||
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
failure_time = start_at + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeNode:
|
||||
execution_id = "exec-1"
|
||||
id = "node-1"
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="LLM",
|
||||
start_at=start_at,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"node-1": FakeNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["node-1"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 1:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == start_at
|
||||
assert fallback_event.finished_at == failure_time
|
||||
assert fallback_event.error == "queue boom"
|
||||
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
|
||||
|
||||
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
|
||||
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
|
||||
child_start = parent_start + timedelta(seconds=3)
|
||||
failure_time = parent_start + timedelta(seconds=5)
|
||||
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
|
||||
|
||||
class FakeIterationNode:
|
||||
execution_id = "iteration-exec"
|
||||
id = "iteration-node"
|
||||
node_type = BuiltinNodeTypes.ITERATION
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
|
||||
yield NodeRunStartedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
node_title="Iteration",
|
||||
start_at=parent_start,
|
||||
)
|
||||
yield NodeRunStartedEvent(
|
||||
id="child-exec",
|
||||
node_id="child-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
node_title="LLM",
|
||||
start_at=child_start,
|
||||
in_iteration_id=self.id,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=MagicMock(),
|
||||
event_queue=MagicMock(),
|
||||
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
|
||||
layers=[],
|
||||
)
|
||||
|
||||
worker._ready_queue.get.side_effect = ["iteration-node"]
|
||||
|
||||
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
|
||||
captured_events.append(event)
|
||||
if len(captured_events) == 2:
|
||||
raise RuntimeError("queue boom")
|
||||
worker.stop()
|
||||
|
||||
worker._event_queue.put.side_effect = put_side_effect
|
||||
|
||||
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
|
||||
worker.run()
|
||||
|
||||
fallback_event = captured_events[-1]
|
||||
|
||||
assert isinstance(fallback_event, NodeRunFailedEvent)
|
||||
assert fallback_event.start_at == parent_start
|
||||
assert fallback_event.finished_at == failure_time
|
||||
@ -0,0 +1,63 @@
import time
from contextlib import nullcontext
from datetime import UTC, datetime

import pytest

from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.iteration_node import IterationNode


def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration", "output"],
is_parallel=True,
parallel_nums=2,
error_handle_mode=ErrorHandleMode.TERMINATED,
)
node._capture_execution_context = lambda: nullcontext()
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)

def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
return (
0.1 + (index * 0.1),
[
NodeRunSucceededEvent(
id=f"exec-{index}",
node_id=f"llm-{index}",
node_type=BuiltinNodeTypes.LLM,
start_at=datetime.now(UTC).replace(tzinfo=None),
),
],
f"output-{item}",
{},
LLMUsage.empty_usage(),
)

node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel

outputs: list[object] = []
iter_run_map: dict[str, float] = {}
usage_accumulator = [LLMUsage.empty_usage()]

generator = node._execute_parallel_iterations(
iterator_list_value=["a", "b"],
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)

for _ in generator:
# Simulate a slow consumer replaying buffered events.
time.sleep(0.02)

assert outputs == ["output-a", "output-b"]
assert iter_run_map["0"] == pytest.approx(0.1)
assert iter_run_map["1"] == pytest.approx(0.2)
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
@ -0,0 +1,106 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
@ -0,0 +1,63 @@
from collections.abc import Mapping

from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params


def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
init_params = build_test_graph_init_params(
graph_config=graph_config,
user_from="account",
invoke_from="debugger",
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable(user_id="user", files=[]),
user_inputs={"payload": "value"},
),
start_at=0.0,
)
return init_params, runtime_state


def _build_node_config() -> NodeConfigDict:
return NodeConfigDictAdapter.validate_python(
{
"id": "node-1",
"data": {
"type": TRIGGER_PLUGIN_NODE_TYPE,
"title": "Trigger Event",
"plugin_id": "plugin-id",
"provider_id": "provider-id",
"event_name": "event-name",
"subscription_id": "subscription-id",
"plugin_unique_identifier": "plugin-unique-identifier",
"event_parameters": {},
},
}
)


def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
id="node-1",
config=_build_node_config(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

result = node._run()

assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
"provider_id": "provider-id",
"event_name": "event-name",
"plugin_unique_identifier": "plugin-unique-identifier",
}
19
api/tests/unit_tests/dify_graph/node_events/test_base.py
Normal file
@ -0,0 +1,19 @@
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.node_events.base import NodeRunResult


def test_node_run_result_accepts_trigger_info_metadata() -> None:
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": "provider-id",
"event_name": "event-name",
}
},
)

assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
"provider_id": "provider-id",
"event_name": "event-name",
}
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
@ -0,0 +1,19 @@
import pytest

from models.enums import CreatorUserRole


def test_creator_user_role_missing_maps_hyphen_to_enum():
# given an alias with hyphen
value = "end-user"

# when converting to enum (invokes StrEnum._missing_ override)
role = CreatorUserRole(value)

# then it should map to END_USER
assert role is CreatorUserRole.END_USER


def test_creator_user_role_missing_raises_for_unknown():
with pytest.raises(ValueError):
CreatorUserRole("unknown")
@ -443,7 +443,7 @@ def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monk

def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None:
firecrawl_instance = MagicMock()
firecrawl_instance.check_crawl_status.return_value = {"status": "completed"}
firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 1, "current": 1, "data": []}
monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))

redis_mock = MagicMock()

65
api/uv.lock
generated
@ -1533,7 +1533,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.13.1"
|
||||
version = "1.13.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
@ -5405,11 +5405,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -7248,30 +7248,43 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "ujson"
|
||||
version = "5.9.0"
|
||||
version = "5.12.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6e/54/6f2bdac7117e89a47de4511c9f01732a283457ab1bf856e1e51aa861619e/ujson-5.9.0.tar.gz", hash = "sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532", size = 7154214, upload-time = "2023-12-10T22:50:34.812Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cb/3e/c35530c5ffc25b71c59ae0cd7b8f99df37313daa162ce1e2f7925f7c2877/ujson-5.12.0.tar.gz", hash = "sha256:14b2e1eb528d77bc0f4c5bd1a7ebc05e02b5b41beefb7e8567c9675b8b13bcf4", size = 7158451, upload-time = "2026-03-11T22:19:30.397Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/ca/ae3a6ca5b4f82ce654d6ac3dde5e59520537e20939592061ba506f4e569a/ujson-5.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b23bbb46334ce51ddb5dded60c662fbf7bb74a37b8f87221c5b0fec1ec6454b", size = 57753, upload-time = "2023-12-10T22:49:03.939Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/5f/c27fa9a1562c96d978c39852b48063c3ca480758f3088dcfc0f3b09f8e93/ujson-5.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6974b3a7c17bbf829e6c3bfdc5823c67922e44ff169851a755eab79a3dd31ec0", size = 54092, upload-time = "2023-12-10T22:49:05.194Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/f3/1431713de9e5992e5e33ba459b4de28f83904233958855d27da820a101f9/ujson-5.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5964ea916edfe24af1f4cc68488448fbb1ec27a3ddcddc2b236da575c12c8ae", size = 51675, upload-time = "2023-12-10T22:49:06.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/93/de6fff3ae06351f3b1c372f675fe69bc180f93d237c9e496c05802173dd6/ujson-5.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba7cac47dd65ff88571eceeff48bf30ed5eb9c67b34b88cb22869b7aa19600d", size = 53246, upload-time = "2023-12-10T22:49:07.691Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/73/db509fe1d7da62a15c0769c398cec66bdfc61a8bdffaf7dfa9d973e3d65c/ujson-5.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bbd91a151a8f3358c29355a491e915eb203f607267a25e6ab10531b3b157c5e", size = 58182, upload-time = "2023-12-10T22:49:08.89Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/a8/6be607fa3e1fa3e1c9b53f5de5acad33b073b6cc9145803e00bcafa729a8/ujson-5.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:829a69d451a49c0de14a9fecb2a2d544a9b2c884c2b542adb243b683a6f15908", size = 584493, upload-time = "2023-12-10T22:49:11.043Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/c7/33822c2f1a8175e841e2bc378ffb2c1109ce9280f14cedb1b2fa0caf3145/ujson-5.9.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a807ae73c46ad5db161a7e883eec0fbe1bebc6a54890152ccc63072c4884823b", size = 656038, upload-time = "2023-12-10T22:49:12.651Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/b8/5309fbb299d5fcac12bbf3db20896db5178392904abe6b992da233dc69d6/ujson-5.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8fc2aa18b13d97b3c8ccecdf1a3c405f411a6e96adeee94233058c44ff92617d", size = 597643, upload-time = "2023-12-10T22:49:14.883Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/64/7b63043b95dd78feed401b9973958af62645a6d19b72b6e83d1ea5af07e0/ujson-5.9.0-cp311-cp311-win32.whl", hash = "sha256:70e06849dfeb2548be48fdd3ceb53300640bc8100c379d6e19d78045e9c26120", size = 38342, upload-time = "2023-12-10T22:49:16.854Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/13/a3cd1fc3a1126d30b558b6235c05e2d26eeaacba4979ee2fd2b5745c136d/ujson-5.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7309d063cd392811acc49b5016728a5e1b46ab9907d321ebbe1c2156bc3c0b99", size = 41923, upload-time = "2023-12-10T22:49:17.983Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/7e/c37fca6cd924931fa62d615cdbf5921f34481085705271696eff38b38867/ujson-5.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:20509a8c9f775b3a511e308bbe0b72897ba6b800767a7c90c5cca59d20d7c42c", size = 57834, upload-time = "2023-12-10T22:49:19.799Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/44/2753e902ee19bf6ccaf0bda02f1f0037f92a9769a5d31319905e3de645b4/ujson-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b28407cfe315bd1b34f1ebe65d3bd735d6b36d409b334100be8cdffae2177b2f", size = 54119, upload-time = "2023-12-10T22:49:21.039Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/06/2317433e394450bc44afe32b6c39d5a51014da4c6f6cfc2ae7bf7b4a2922/ujson-5.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d302bd17989b6bd90d49bade66943c78f9e3670407dbc53ebcf61271cadc399", size = 51658, upload-time = "2023-12-10T22:49:22.494Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/3a/2acf0da085d96953580b46941504aa3c91a1dd38701b9e9bfa43e2803467/ujson-5.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f21315f51e0db8ee245e33a649dd2d9dce0594522de6f278d62f15f998e050e", size = 53370, upload-time = "2023-12-10T22:49:24.045Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/32/737e6c4b1841720f88ae88ec91f582dc21174bd40742739e1fa16a0c9ffa/ujson-5.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5635b78b636a54a86fdbf6f027e461aa6c6b948363bdf8d4fbb56a42b7388320", size = 58278, upload-time = "2023-12-10T22:49:25.261Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/dc/3fda97f1ad070ccf2af597fb67dde358bc698ffecebe3bc77991d60e4fe5/ujson-5.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82b5a56609f1235d72835ee109163c7041b30920d70fe7dac9176c64df87c164", size = 584418, upload-time = "2023-12-10T22:49:27.573Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/57/e4083d774fcd8ff3089c0ff19c424abe33f23e72c6578a8172bf65131992/ujson-5.9.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5ca35f484622fd208f55041b042d9d94f3b2c9c5add4e9af5ee9946d2d30db01", size = 656126, upload-time = "2023-12-10T22:49:29.509Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/c3/8c6d5f6506ca9fcedd5a211e30a7d5ee053dc05caf23dae650e1f897effb/ujson-5.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:829b824953ebad76d46e4ae709e940bb229e8999e40881338b3cc94c771b876c", size = 597795, upload-time = "2023-12-10T22:49:31.029Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/5a/a231f0cd305a34cf2d16930304132db3a7a8c3997b367dd38fc8f8dfae36/ujson-5.9.0-cp312-cp312-win32.whl", hash = "sha256:25fa46e4ff0a2deecbcf7100af3a5d70090b461906f2299506485ff31d9ec437", size = 38495, upload-time = "2023-12-10T22:49:33.2Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/b7/18b841b44760ed298acdb150608dccdc045c41655e0bae4441f29bcab872/ujson-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:60718f1720a61560618eff3b56fd517d107518d3c0160ca7a5a66ac949c6cf1c", size = 42088, upload-time = "2023-12-10T22:49:34.921Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/22/fd22e2f6766bae934d3050517ca47d463016bd8688508d1ecc1baa18a7ad/ujson-5.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58a11cb49482f1a095a2bd9a1d81dd7c8fb5d2357f959ece85db4e46a825fd00", size = 56139, upload-time = "2026-03-11T22:18:04.591Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/fd/6839adff4fc0164cbcecafa2857ba08a6eaeedd7e098d6713cb899a91383/ujson-5.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9b3cf13facf6f77c283af0e1713e5e8c47a0fe295af81326cb3cb4380212e797", size = 53836, upload-time = "2026-03-11T22:18:05.662Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/b0/0c19faac62d68ceeffa83a08dc3d71b8462cf5064d0e7e0b15ba19898dad/ujson-5.12.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb94245a715b4d6e24689de12772b85329a1f9946cbf6187923a64ecdea39e65", size = 57851, upload-time = "2026-03-11T22:18:06.744Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/f6/e7fd283788de73b86e99e08256726bb385923249c21dcd306e59d532a1a1/ujson-5.12.0-cp311-cp311-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:0fe6b8b8968e11dd9b2348bd508f0f57cf49ab3512064b36bc4117328218718e", size = 59906, upload-time = "2026-03-11T22:18:07.791Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/3a/b100735a2b43ee6e8fe4c883768e362f53576f964d4ea841991060aeaf35/ujson-5.12.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89e302abd3749f6d6699691747969a5d85f7c73081d5ed7e2624c7bd9721a2ab", size = 57409, upload-time = "2026-03-11T22:18:08.79Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/fa/f97cc20c99ca304662191b883ae13ae02912ca7244710016ba0cb8a5be34/ujson-5.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0727363b05ab05ee737a28f6200dc4078bce6b0508e10bd8aab507995a15df61", size = 1037339, upload-time = "2026-03-11T22:18:10.424Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/7a/53ddeda0ffe1420db2f9999897b3cbb920fbcff1849d1f22b196d0f34785/ujson-5.12.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b62cb9a7501e1f5c9ffe190485501349c33e8862dde4377df774e40b8166871f", size = 1196625, upload-time = "2026-03-11T22:18:11.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/1a/4c64a6bef522e9baf195dd5be151bc815cd4896c50c6e2489599edcda85f/ujson-5.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a6ec5bf6bc361f2f0f9644907a36ce527715b488988a8df534120e5c34eeda94", size = 1089669, upload-time = "2026-03-11T22:18:13.343Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/11/8ccb109f5777ec0d9fb826695a9e2ac36ae94c1949fc8b1e4d23a5bd067a/ujson-5.12.0-cp311-cp311-win32.whl", hash = "sha256:006428d3813b87477d72d306c40c09f898a41b968e57b15a7d88454ecc42a3fb", size = 39648, upload-time = "2026-03-11T22:18:14.785Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/e3/87fc4c27b20d5125cff7ce52d17ea7698b22b74426da0df238e3efcb0cf2/ujson-5.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:40aa43a7a3a8d2f05e79900858053d697a88a605e3887be178b43acbcd781161", size = 43876, upload-time = "2026-03-11T22:18:15.768Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/21/324f0548a8c8c48e3e222eaed15fb6d48c796593002b206b4a28a89e445f/ujson-5.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:561f89cc82deeae82e37d4a4764184926fb432f740a9691563a391b13f7339a4", size = 38553, upload-time = "2026-03-11T22:18:17.251Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/f6/ac763d2108d28f3a40bb3ae7d2fafab52ca31b36c2908a4ad02cd3ceba2a/ujson-5.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09b4beff9cc91d445d5818632907b85fb06943b61cb346919ce202668bf6794a", size = 56326, upload-time = "2026-03-11T22:18:18.467Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/46/d0b3af64dcdc549f9996521c8be6d860ac843a18a190ffc8affeb7259687/ujson-5.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca0c7ce828bb76ab78b3991904b477c2fd0f711d7815c252d1ef28ff9450b052", size = 53910, upload-time = "2026-03-11T22:18:19.502Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/10/853c723bcabc3e9825a079019055fc99e71b85c6bae600607a2b9d31d18d/ujson-5.12.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d79c6635ccffcbfc1d5c045874ba36b594589be81d50d43472570bb8de9c57", size = 57754, upload-time = "2026-03-11T22:18:20.874Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/c6/6e024830d988f521f144ead641981c1f7a82c17ad1927c22de3242565f5c/ujson-5.12.0-cp312-cp312-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:7e07f6f644d2c44d53b7a320a084eef98063651912c1b9449b5f45fcbdc6ccd2", size = 59936, upload-time = "2026-03-11T22:18:21.924Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/c9/c5f236af5abe06b720b40b88819d00d10182d2247b1664e487b3ed9229cf/ujson-5.12.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:085b6ce182cdd6657481c7c4003a417e0655c4f6e58b76f26ee18f0ae21db827", size = 57463, upload-time = "2026-03-11T22:18:22.924Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/04/41342d9ef68e793a87d84e4531a150c2b682f3bcedfe59a7a5e3f73e9213/ujson-5.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16b4fe9c97dc605f5e1887a9e1224287291e35c56cbc379f8aa44b6b7bcfe2bb", size = 1037239, upload-time = "2026-03-11T22:18:24.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/81/dc2b7617d5812670d4ff4a42f6dd77926430ee52df0dedb2aec7990b2034/ujson-5.12.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0d2e8db5ade3736a163906154ca686203acc7d1d30736cbf577c730d13653d84", size = 1196713, upload-time = "2026-03-11T22:18:25.391Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/9c/80acff0504f92459ed69e80a176286e32ca0147ac6a8252cd0659aad3227/ujson-5.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:93bc91fdadcf046da37a214eaa714574e7e9b1913568e93bb09527b2ceb7f759", size = 1089742, upload-time = "2026-03-11T22:18:26.738Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/f0/123ffaac17e45ef2b915e3e3303f8f4ea78bb8d42afad828844e08622b1e/ujson-5.12.0-cp312-cp312-win32.whl", hash = "sha256:2a248750abce1c76fbd11b2e1d88b95401e72819295c3b851ec73399d6849b3d", size = 39773, upload-time = "2026-03-11T22:18:28.244Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/20/f3bd2b069c242c2b22a69e033bfe224d1d15d3649e6cd7cc7085bb1412ff/ujson-5.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:1b5c6ceb65fecd28a1d20d1eba9dbfa992612b86594e4b6d47bb580d2dd6bcb3", size = 44040, upload-time = "2026-03-11T22:18:29.236Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/a7/01b5a0bcded14cd2522b218f2edc3533b0fcbccdea01f3e14a2b699071aa/ujson-5.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:9a5fcbe7b949f2e95c47ea8a80b410fcdf2da61c98553b45a4ee875580418b68", size = 38526, upload-time = "2026-03-11T22:18:30.551Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/3c/5ee154d505d1aad2debc4ba38b1a60ae1949b26cdb5fa070e85e320d6b64/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_10_13_x86_64.whl", hash = "sha256:bf85a00ac3b56a1e7a19c5be7b02b5180a0895ac4d3c234d717a55e86960691c", size = 54494, upload-time = "2026-03-11T22:19:13.035Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/b3/9496ec399ec921e434a93b340bd5052999030b7ac364be4cbe5365ac6b20/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:64df53eef4ac857eb5816a56e2885ccf0d7dff6333c94065c93b39c51063e01d", size = 57999, upload-time = "2026-03-11T22:19:14.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/da/e9ae98133336e7c0d50b43626c3f2327937cecfa354d844e02ac17379ed1/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c0aed6a4439994c9666fb8a5b6c4eac94d4ef6ddc95f9b806a599ef83547e3b", size = 54518, upload-time = "2026-03-11T22:19:15.4Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/10/978d89dded6bb1558cd46ba78f4351198bd2346db8a8ee1a94119022ce40/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efae5df7a8cc8bdb1037b0f786b044ce281081441df5418c3a0f0e1f86fe7bb3", size = 55736, upload-time = "2026-03-11T22:19:16.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/25/1df8e6217c92e57a1266bf5be750b1dddc126ee96e53fe959d5693503bc6/ujson-5.12.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:8712b61eb1b74a4478cfd1c54f576056199e9f093659334aeb5c4a6b385338e5", size = 44615, upload-time = "2026-03-11T22:19:17.53Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/fa/f4a957dddb99bd68c8be91928c0b6fefa7aa8aafc92c93f5d1e8b32f6702/ujson-5.12.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:871c0e5102e47995b0e37e8df7819a894a6c3da0d097545cd1f9f1f7d7079927", size = 52145, upload-time = "2026-03-11T22:19:18.566Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/6e/50b5cf612de1ca06c7effdc5a5d7e815774dee85a5858f1882c425553b82/ujson-5.12.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:56ba3f7abbd6b0bb282a544dc38406d1a188d8bb9164f49fdb9c2fee62cb29da", size = 49577, upload-time = "2026-03-11T22:19:19.627Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/24/b6713fa9897774502cd4c2d6955bb4933349f7d84c3aa805531c382a4209/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c5a52987a990eb1bae55f9000994f1afdb0326c154fb089992f839ab3c30688", size = 50807, upload-time = "2026-03-11T22:19:20.778Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/b6/c0e0f7901180ef80d16f3a4bccb5dc8b01515a717336a62928963a07b80b/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:adf28d13a33f9d750fe7a78fb481cac298fa257d8863d8727b2ea4455ea41235", size = 56972, upload-time = "2026-03-11T22:19:21.84Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/02/a9/05d91b4295ea7239151eb08cf240e5a2ba969012fda50bc27bcb1ea9cd71/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51acc750ec7a2df786cdc868fb16fa04abd6269a01d58cf59bafc57978773d8e", size = 52045, upload-time = "2026-03-11T22:19:22.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/7a/92047d32bf6f2d9db64605fc32e8eb0e0dd68b671eaafc12a464f69c4af4/ujson-5.12.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:ab9056d94e5db513d9313b34394f3a3b83e6301a581c28ad67773434f3faccab", size = 44053, upload-time = "2026-03-11T22:19:23.918Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
16
codecov.yml
Normal file
@ -0,0 +1,16 @@
coverage:
status:
project:
default:
target: auto

flags:
web:
paths:
- "web/"
carryforward: true

api:
paths:
- "api/"
carryforward: true

@ -21,7 +21,7 @@ services:

# API service
api:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:1.13.1
image: langgenius/dify-web:1.13.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@ -728,7 +728,7 @@ services:

# API service
api:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -770,7 +770,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -809,7 +809,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.13.1
image: langgenius/dify-api:1.13.2
restart: always
environment:
# Use the shared environment variables.
@ -839,7 +839,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:1.13.1
image: langgenius/dify-web:1.13.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type'
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ALL_PLANS } from '@/app/components/billing/config'
|
||||
import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item'
|
||||
@ -21,7 +22,6 @@ let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockFetchSubscriptionUrls = vi.fn()
|
||||
const mockInvoices = vi.fn()
|
||||
const mockOpenAsyncWindow = vi.fn()
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
// ─── Context mocks ───────────────────────────────────────────────────────────
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@ -49,10 +49,6 @@ vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => mockOpenAsyncWindow,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
@ -82,12 +78,15 @@ const renderCloudPlanItem = ({
|
||||
canPay = true,
|
||||
}: RenderCloudPlanItemOptions = {}) => {
|
||||
return render(
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>,
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
|
||||
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
|
||||
@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
// Should not proceed with payment
|
||||
expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled()
|
||||
|
||||
@ -10,12 +10,12 @@
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config'
|
||||
import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item'
|
||||
import { SelfHostedPlan } from '@/app/components/billing/type'
|
||||
|
||||
let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
const originalLocation = window.location
|
||||
let assignedHref = ''
|
||||
@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({
|
||||
AwsMarketplaceDark: () => <span data-testid="icon-aws-dark" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({
|
||||
default: ({ plan }: { plan: string }) => (
|
||||
<div data-testid={`self-hosted-list-${plan}`}>Features</div>
|
||||
@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record<string, unknown> = {}) => {
|
||||
}
|
||||
}
|
||||
|
||||
const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<SelfHostedPlanItem plan={plan} />
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Self-Hosted Plan Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
|
||||
// Mock window.location with minimal getter/setter (Location props are non-enumerable)
|
||||
@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
// ─── 1. Plan Rendering ──────────────────────────────────────────────────
|
||||
describe('Plan rendering', () => {
|
||||
it('should render community plan with name and description', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render premium plan with cloud provider icons', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('icon-azure')).toBeInTheDocument()
|
||||
@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
})
|
||||
|
||||
it('should render enterprise plan without cloud provider icons', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show price tip for community (free) plan', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show price tip for premium plan', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render features list for each plan', () => {
|
||||
const { unmount: unmount1 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument()
|
||||
unmount1()
|
||||
|
||||
const { unmount: unmount2 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument()
|
||||
unmount2()
|
||||
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show AWS marketplace icon for premium plan button', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument()
|
||||
})
|
||||
@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
describe('Navigation flow', () => {
|
||||
it('should redirect to GitHub when clicking community plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to AWS Marketplace when clicking premium plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to Typeform when clicking enterprise plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks community button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
// Should NOT redirect
|
||||
expect(assignedHref).toBe('')
|
||||
@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks premium button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks enterprise button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
|
||||
@@ -1,221 +0,0 @@
import {
  buildGitDiffRevisionArgs,
  getChangedBranchCoverage,
  getChangedStatementCoverage,
  getIgnoredChangedLinesFromSource,
  normalizeToRepoRelative,
  parseChangedLineMap,
} from '../scripts/check-components-diff-coverage-lib.mjs'

describe('check-components-diff-coverage helpers', () => {
  it('should build exact and merge-base git diff revision args', () => {
    expect(buildGitDiffRevisionArgs('base-sha', 'head-sha', 'exact')).toEqual(['base-sha', 'head-sha'])
    expect(buildGitDiffRevisionArgs('base-sha', 'head-sha')).toEqual(['base-sha...head-sha'])
  })

  it('should parse changed line maps from unified diffs', () => {
    const diff = [
      'diff --git a/web/app/components/share/a.ts b/web/app/components/share/a.ts',
      '+++ b/web/app/components/share/a.ts',
      '@@ -10,0 +11,2 @@',
      '+const a = 1',
      '+const b = 2',
      'diff --git a/web/app/components/base/b.ts b/web/app/components/base/b.ts',
      '+++ b/web/app/components/base/b.ts',
      '@@ -20 +21 @@',
      '+const c = 3',
      'diff --git a/web/README.md b/web/README.md',
      '+++ b/web/README.md',
      '@@ -1 +1 @@',
      '+ignore me',
    ].join('\n')

    const lineMap = parseChangedLineMap(diff, (filePath: string) => filePath.startsWith('web/app/components/'))

    expect([...lineMap.entries()]).toEqual([
      ['web/app/components/share/a.ts', new Set([11, 12])],
      ['web/app/components/base/b.ts', new Set([21])],
    ])
  })

  it('should normalize coverage and absolute paths to repo-relative paths', () => {
    const repoRoot = '/repo'
    const webRoot = '/repo/web'

    expect(normalizeToRepoRelative('web/app/components/share/a.ts', {
      appComponentsCoveragePrefix: 'app/components/',
      appComponentsPrefix: 'web/app/components/',
      repoRoot,
      sharedTestPrefix: 'web/__tests__/',
      webRoot,
    })).toBe('web/app/components/share/a.ts')

    expect(normalizeToRepoRelative('app/components/share/a.ts', {
      appComponentsCoveragePrefix: 'app/components/',
      appComponentsPrefix: 'web/app/components/',
      repoRoot,
      sharedTestPrefix: 'web/__tests__/',
      webRoot,
    })).toBe('web/app/components/share/a.ts')

    expect(normalizeToRepoRelative('/repo/web/app/components/share/a.ts', {
      appComponentsCoveragePrefix: 'app/components/',
      appComponentsPrefix: 'web/app/components/',
      repoRoot,
      sharedTestPrefix: 'web/__tests__/',
      webRoot,
    })).toBe('web/app/components/share/a.ts')
  })

  it('should calculate changed statement coverage from changed lines', () => {
    const entry = {
      s: { 0: 1, 1: 0 },
      statementMap: {
        0: { start: { line: 10 }, end: { line: 10 } },
        1: { start: { line: 12 }, end: { line: 13 } },
      },
    }

    const coverage = getChangedStatementCoverage(entry, new Set([10, 12]))

    expect(coverage).toEqual({
      covered: 1,
      total: 2,
      uncoveredLines: [12],
    })
  })

  it('should report the first changed line inside a multi-line uncovered statement', () => {
    const entry = {
      s: { 0: 0 },
      statementMap: {
        0: { start: { line: 10 }, end: { line: 14 } },
      },
    }

    const coverage = getChangedStatementCoverage(entry, new Set([13, 14]))

    expect(coverage).toEqual({
      covered: 0,
      total: 1,
      uncoveredLines: [13],
    })
  })

  it('should fail changed lines when a source file has no coverage entry', () => {
    const coverage = getChangedStatementCoverage(undefined, new Set([42, 43]))

    expect(coverage).toEqual({
      covered: 0,
      total: 2,
      uncoveredLines: [42, 43],
    })
  })

  it('should calculate changed branch coverage using changed branch definitions', () => {
    const entry = {
      b: {
        0: [1, 0],
      },
      branchMap: {
        0: {
          line: 20,
          loc: { start: { line: 20 }, end: { line: 20 } },
          locations: [
            { start: { line: 20 }, end: { line: 20 } },
            { start: { line: 21 }, end: { line: 21 } },
          ],
          type: 'if',
        },
      },
    }

    const coverage = getChangedBranchCoverage(entry, new Set([20]))

    expect(coverage).toEqual({
      covered: 1,
      total: 2,
      uncoveredBranches: [
        { armIndex: 1, line: 21 },
      ],
    })
  })

  it('should report the first changed line inside a multi-line uncovered branch arm', () => {
    const entry = {
      b: {
        0: [0, 0],
      },
      branchMap: {
        0: {
          line: 30,
          loc: { start: { line: 30 }, end: { line: 35 } },
          locations: [
            { start: { line: 31 }, end: { line: 34 } },
            { start: { line: 35 }, end: { line: 38 } },
          ],
          type: 'if',
        },
      },
    }

    const coverage = getChangedBranchCoverage(entry, new Set([33]))

    expect(coverage).toEqual({
      covered: 0,
      total: 1,
      uncoveredBranches: [
        { armIndex: 0, line: 33 },
      ],
    })
  })

  it('should require all branch arms when the branch condition changes', () => {
    const entry = {
      b: {
        0: [0, 0],
      },
      branchMap: {
        0: {
          line: 30,
          loc: { start: { line: 30 }, end: { line: 35 } },
          locations: [
            { start: { line: 31 }, end: { line: 34 } },
            { start: { line: 35 }, end: { line: 38 } },
          ],
          type: 'if',
        },
      },
    }

    const coverage = getChangedBranchCoverage(entry, new Set([30]))

    expect(coverage).toEqual({
      covered: 0,
      total: 2,
      uncoveredBranches: [
        { armIndex: 0, line: 31 },
        { armIndex: 1, line: 35 },
      ],
    })
  })

  it('should ignore changed lines with valid pragma reasons and report invalid pragmas', () => {
    const sourceCode = [
      'const a = 1',
      'const b = 2 // diff-coverage-ignore-line: defensive fallback',
      'const c = 3 // diff-coverage-ignore-line:',
      'const d = 4 // diff-coverage-ignore-line: not changed',
    ].join('\n')

    const result = getIgnoredChangedLinesFromSource(sourceCode, new Set([2, 3]))

    expect([...result.effectiveChangedLines]).toEqual([3])
    expect([...result.ignoredLines.entries()]).toEqual([
      [2, 'defensive fallback'],
    ])
    expect(result.invalidPragmas).toEqual([
      { line: 3, reason: 'missing ignore reason' },
    ])
  })
})
@@ -1,115 +0,0 @@
import fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
import { afterEach, describe, expect, it } from 'vitest'
import {
  collectComponentCoverageExcludedFiles,
  COMPONENT_COVERAGE_EXCLUDE_LABEL,
  getComponentCoverageExclusionReasons,
} from '../scripts/component-coverage-filters.mjs'

describe('component coverage filters', () => {
  describe('getComponentCoverageExclusionReasons', () => {
    it('should exclude type-only files by basename', () => {
      expect(
        getComponentCoverageExclusionReasons(
          'web/app/components/share/text-generation/types.ts',
          'export type ShareMode = "run-once" | "run-batch"',
        ),
      ).toContain('type-only')
    })

    it('should exclude pure barrel files', () => {
      expect(
        getComponentCoverageExclusionReasons(
          'web/app/components/base/amplitude/index.ts',
          [
            'export { default } from "./AmplitudeProvider"',
            'export { resetUser, trackEvent } from "./utils"',
          ].join('\n'),
        ),
      ).toContain('pure-barrel')
    })

    it('should exclude generated files from marker comments', () => {
      expect(
        getComponentCoverageExclusionReasons(
          'web/app/components/base/icons/src/vender/workflow/Answer.tsx',
          [
            '// GENERATE BY script',
            '// DON NOT EDIT IT MANUALLY',
            'export default function Icon() {',
            ' return null',
            '}',
          ].join('\n'),
        ),
      ).toContain('generated')
    })

    it('should exclude pure static files with exported constants only', () => {
      expect(
        getComponentCoverageExclusionReasons(
          'web/app/components/workflow/note-node/constants.ts',
          [
            'import { NoteTheme } from "./types"',
            'export const CUSTOM_NOTE_NODE = "custom-note"',
            'export const THEME_MAP = {',
            ' [NoteTheme.blue]: { title: "bg-blue-100" },',
            '}',
          ].join('\n'),
        ),
      ).toContain('pure-static')
    })

    it('should keep runtime logic files tracked', () => {
      expect(
        getComponentCoverageExclusionReasons(
          'web/app/components/workflow/nodes/trigger-schedule/default.ts',
          [
            'const validate = (value: string) => value.trim()',
            'export const nodeDefault = {',
            ' value: validate("x"),',
            '}',
          ].join('\n'),
        ),
      ).toEqual([])
    })
  })

  describe('collectComponentCoverageExcludedFiles', () => {
    const tempDirs: string[] = []

    afterEach(() => {
      for (const dir of tempDirs)
        fs.rmSync(dir, { recursive: true, force: true })
      tempDirs.length = 0
    })

    it('should collect excluded files for coverage config and keep runtime files out', () => {
      const rootDir = fs.mkdtempSync(path.join(os.tmpdir(), 'component-coverage-filters-'))
      tempDirs.push(rootDir)

      fs.mkdirSync(path.join(rootDir, 'barrel'), { recursive: true })
      fs.mkdirSync(path.join(rootDir, 'icons'), { recursive: true })
      fs.mkdirSync(path.join(rootDir, 'static'), { recursive: true })
      fs.mkdirSync(path.join(rootDir, 'runtime'), { recursive: true })

      fs.writeFileSync(path.join(rootDir, 'barrel', 'index.ts'), 'export { default } from "./Button"\n')
      fs.writeFileSync(path.join(rootDir, 'icons', 'generated-icon.tsx'), '// @generated\nexport default function Icon() { return null }\n')
      fs.writeFileSync(path.join(rootDir, 'static', 'constants.ts'), 'export const COLORS = { primary: "#fff" }\n')
      fs.writeFileSync(path.join(rootDir, 'runtime', 'config.ts'), 'export const config = makeConfig()\n')
      fs.writeFileSync(path.join(rootDir, 'runtime', 'types.ts'), 'export type Config = { value: string }\n')

      expect(collectComponentCoverageExcludedFiles(rootDir, { pathPrefix: 'app/components' })).toEqual([
        'app/components/barrel/index.ts',
        'app/components/icons/generated-icon.tsx',
        'app/components/runtime/types.ts',
        'app/components/static/constants.ts',
      ])
    })
  })

  it('should describe the excluded coverage categories', () => {
    expect(COMPONENT_COVERAGE_EXCLUDE_LABEL).toBe('type-only files, pure barrel files, generated files, pure static files')
  })
})
@@ -1,72 +0,0 @@
import {
  getCoverageStats,
  isRelevantTestFile,
  isTrackedComponentSourceFile,
  loadTrackedCoverageEntries,
} from '../scripts/components-coverage-common.mjs'

describe('components coverage common helpers', () => {
  it('should identify tracked component source files and relevant tests', () => {
    const excludedComponentCoverageFiles = new Set([
      'web/app/components/share/types.ts',
    ])

    expect(isTrackedComponentSourceFile('web/app/components/share/index.tsx', excludedComponentCoverageFiles)).toBe(true)
    expect(isTrackedComponentSourceFile('web/app/components/share/types.ts', excludedComponentCoverageFiles)).toBe(false)
    expect(isTrackedComponentSourceFile('web/app/components/provider/index.tsx', excludedComponentCoverageFiles)).toBe(false)

    expect(isRelevantTestFile('web/__tests__/share/text-generation-run-once-flow.test.tsx')).toBe(true)
    expect(isRelevantTestFile('web/app/components/share/__tests__/index.spec.tsx')).toBe(true)
    expect(isRelevantTestFile('web/utils/format.spec.ts')).toBe(false)
  })

  it('should load only tracked coverage entries from mixed coverage paths', () => {
    const context = {
      excludedComponentCoverageFiles: new Set([
        'web/app/components/share/types.ts',
      ]),
      repoRoot: '/repo',
      webRoot: '/repo/web',
    }
    const coverage = {
      '/repo/web/app/components/provider/index.tsx': {
        path: '/repo/web/app/components/provider/index.tsx',
        statementMap: { 0: { start: { line: 1 }, end: { line: 1 } } },
        s: { 0: 1 },
      },
      'app/components/share/index.tsx': {
        path: 'app/components/share/index.tsx',
        statementMap: { 0: { start: { line: 2 }, end: { line: 2 } } },
        s: { 0: 1 },
      },
      'app/components/share/types.ts': {
        path: 'app/components/share/types.ts',
        statementMap: { 0: { start: { line: 3 }, end: { line: 3 } } },
        s: { 0: 1 },
      },
    }

    expect([...loadTrackedCoverageEntries(coverage, context).keys()]).toEqual([
      'web/app/components/share/index.tsx',
    ])
  })

  it('should calculate coverage stats using statement-derived line hits', () => {
    const entry = {
      b: { 0: [1, 0] },
      f: { 0: 1, 1: 0 },
      s: { 0: 1, 1: 0 },
      statementMap: {
        0: { start: { line: 10 }, end: { line: 10 } },
        1: { start: { line: 12 }, end: { line: 13 } },
      },
    }

    expect(getCoverageStats(entry)).toEqual({
      branches: { covered: 1, total: 2 },
      functions: { covered: 1, total: 2 },
      lines: { covered: 1, total: 2 },
      statements: { covered: 1, total: 2 },
    })
  })
})