diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md index 06cb672141..cd775007a0 100644 --- a/.claude/skills/frontend-testing/SKILL.md +++ b/.claude/skills/frontend-testing/SKILL.md @@ -1,5 +1,5 @@ --- -name: Dify Frontend Testing +name: frontend-testing description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. --- @@ -178,7 +178,7 @@ Process in this order for multi-file testing: - **500+ lines**: Consider splitting before testing - **Many dependencies**: Extract logic into hooks first -> 📖 See `guides/workflow.md` for complete workflow details and todo list format. +> 📖 See `references/workflow.md` for complete workflow details and todo list format. ## Testing Strategy @@ -289,17 +289,18 @@ For each test file generated, aim for: - ✅ **>95%** branch coverage - ✅ **>95%** line coverage -> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`. +> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`. ## Detailed Guides For more detailed information, refer to: -- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) -- `guides/mocking.md` - Mock patterns and best practices -- `guides/async-testing.md` - Async operations and API calls -- `guides/domain-components.md` - Workflow, Dataset, Configuration testing -- `guides/common-patterns.md` - Frequently used testing patterns +- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing) +- `references/mocking.md` - Mock patterns and best practices +- `references/async-testing.md` - Async operations and API calls +- `references/domain-components.md` - Workflow, Dataset, Configuration testing +- `references/common-patterns.md` - Frequently used testing patterns +- `references/checklist.md` - Test generation checklist and validation steps ## Authoritative References diff --git a/.claude/skills/frontend-testing/templates/component-test.template.tsx b/.claude/skills/frontend-testing/assets/component-test.template.tsx similarity index 100% rename from .claude/skills/frontend-testing/templates/component-test.template.tsx rename to .claude/skills/frontend-testing/assets/component-test.template.tsx diff --git a/.claude/skills/frontend-testing/templates/hook-test.template.ts b/.claude/skills/frontend-testing/assets/hook-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/templates/hook-test.template.ts rename to .claude/skills/frontend-testing/assets/hook-test.template.ts diff --git a/.claude/skills/frontend-testing/templates/utility-test.template.ts b/.claude/skills/frontend-testing/assets/utility-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/templates/utility-test.template.ts rename to .claude/skills/frontend-testing/assets/utility-test.template.ts diff --git a/.claude/skills/frontend-testing/guides/async-testing.md b/.claude/skills/frontend-testing/references/async-testing.md similarity index 100% rename from .claude/skills/frontend-testing/guides/async-testing.md rename to .claude/skills/frontend-testing/references/async-testing.md diff --git a/.claude/skills/frontend-testing/CHECKLIST.md b/.claude/skills/frontend-testing/references/checklist.md similarity index 100% rename from 
.claude/skills/frontend-testing/CHECKLIST.md rename to .claude/skills/frontend-testing/references/checklist.md diff --git a/.claude/skills/frontend-testing/guides/common-patterns.md b/.claude/skills/frontend-testing/references/common-patterns.md similarity index 100% rename from .claude/skills/frontend-testing/guides/common-patterns.md rename to .claude/skills/frontend-testing/references/common-patterns.md diff --git a/.claude/skills/frontend-testing/guides/domain-components.md b/.claude/skills/frontend-testing/references/domain-components.md similarity index 100% rename from .claude/skills/frontend-testing/guides/domain-components.md rename to .claude/skills/frontend-testing/references/domain-components.md diff --git a/.claude/skills/frontend-testing/guides/mocking.md b/.claude/skills/frontend-testing/references/mocking.md similarity index 100% rename from .claude/skills/frontend-testing/guides/mocking.md rename to .claude/skills/frontend-testing/references/mocking.md diff --git a/.claude/skills/frontend-testing/guides/workflow.md b/.claude/skills/frontend-testing/references/workflow.md similarity index 100% rename from .claude/skills/frontend-testing/guides/workflow.md rename to .claude/skills/frontend-testing/references/workflow.md diff --git a/.codex/skills b/.codex/skills new file mode 120000 index 0000000000..454b8427cd --- /dev/null +++ b/.codex/skills @@ -0,0 +1 @@ +../.claude/skills \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index a26fd076ed..ce9135476f 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -6,7 +6,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d6f326d4dc..4bc4f085c2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,6 +6,12 @@ * @crazywoola @laipz8200 @Yeuoly +# CODEOWNERS file +.github/CODEOWNERS @laipz8200 @crazywoola + +# Docs +docs/ @crazywoola + # Backend (default owner, more specific rules below will override) api/ @QuantumGhost @@ -116,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc api/controllers/web/feature.py @GarfieldDai @GareArc # Backend - Database Migrations -api/migrations/ @snakevash @laipz8200 +api/migrations/ @snakevash @laipz8200 @MRZHUH + +# Backend - Vector 
DB Middleware +api/configs/middleware/vdb/* @JohnJyong # Frontend web/ @iamjoel +# Frontend - Web Tests +.github/workflows/web-tests.yml @iamjoel + # Frontend - App - Orchestration web/app/components/workflow/ @iamjoel @zxhlyh web/app/components/workflow-app/ @iamjoel @zxhlyh @@ -192,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d web/app/signin/ @douxc @iamjoel web/app/signup/ @douxc @iamjoel web/app/reset-password/ @douxc @iamjoel + web/app/install/ @douxc @iamjoel web/app/init/ @douxc @iamjoel web/app/forgot-password/ @douxc @iamjoel @@ -232,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh # Frontend - Workspace web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh + +# Docker +docker/* @laipz8200 diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d7a58ce93d..bafac7bd13 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -66,7 +66,7 @@ jobs: # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx --python 3.13 mdformat . --exclude ".claude/skills/**" + uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md" - name: Install pnpm uses: pnpm/action-setup@v4 @@ -79,7 +79,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies working-directory: ./web diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5a8a34be79..2fb8121f74 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -90,7 +90,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index fe8e2ebc2b..8bb82d5d44 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -55,7 +55,7 @@ jobs: with: node-version: 'lts/*' cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Install dependencies if: env.FILES_CHANGED == 'true' diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3313e58614..b1f32f96c2 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest defaults: run: + shell: bash working-directory: ./web steps: @@ -21,14 +22,7 @@ jobs: with: persist-credentials: false - - name: Check changed files - id: changed-files - uses: tj-actions/changed-files@v46 - with: - files: web/** - - name: Install pnpm - if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: package_json_file: web/package.json @@ -36,23 +30,355 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v4 - if: steps.changed-files.outputs.any_changed == 'true' with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Restore Jest cache + uses: actions/cache@v4 + with: + path: web/.cache/jest + key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }} + restore-keys: | + ${{ runner.os }}-jest- - name: Install dependencies - if: steps.changed-files.outputs.any_changed == 'true' - 
working-directory: ./web run: pnpm install --frozen-lockfile - name: Check i18n types synchronization - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm run check:i18n-types - name: Run tests - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm test + run: | + pnpm exec jest \ + --ci \ + --maxWorkers=100% \ + --coverage \ + --passWithNoTests + + - name: Coverage Summary + if: always() + id: coverage-summary + run: | + set -eo pipefail + + COVERAGE_FILE="coverage/coverage-final.json" + COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" + + if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then + echo "has_coverage=false" >> "$GITHUB_OUTPUT" + echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY" + echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + + echo "has_coverage=true" >> "$GITHUB_OUTPUT" + + node <<'NODE' >> "$GITHUB_STEP_SUMMARY" + const fs = require('fs'); + const path = require('path'); + let libCoverage = null; + + try { + libCoverage = require('istanbul-lib-coverage'); + } catch (error) { + libCoverage = null; + } + + const summaryPath = path.join('coverage', 'coverage-summary.json'); + const finalPath = path.join('coverage', 'coverage-final.json'); + + const hasSummary = fs.existsSync(summaryPath); + const hasFinal = fs.existsSync(finalPath); + + if (!hasSummary && !hasFinal) { + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('No coverage data found.'); + process.exit(0); + } + + const summary = hasSummary + ? JSON.parse(fs.readFileSync(summaryPath, 'utf8')) + : null; + const coverage = hasFinal + ? JSON.parse(fs.readFileSync(finalPath, 'utf8')) + : null; + + const getLineCoverageFromStatements = (statementMap, statementHits) => { + const lineHits = {}; + + if (!statementMap || !statementHits) { + return lineHits; + } + + Object.entries(statementMap).forEach(([key, statement]) => { + const line = statement?.start?.line; + if (!line) { + return; + } + const hits = statementHits[key] ?? 0; + const previous = lineHits[line]; + lineHits[line] = previous === undefined ? hits : Math.max(previous, hits); + }); + + return lineHits; + }; + + const getFileCoverage = (entry) => ( + libCoverage ? libCoverage.createFileCoverage(entry) : null + ); + + const getLineHits = (entry, fileCoverage) => { + const lineHits = entry.l ?? {}; + if (Object.keys(lineHits).length > 0) { + return lineHits; + } + if (fileCoverage) { + return fileCoverage.getLineCoverage(); + } + return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {}); + }; + + const getUncoveredLines = (entry, fileCoverage, lineHits) => { + if (lineHits && Object.keys(lineHits).length > 0) { + return Object.entries(lineHits) + .filter(([, count]) => count === 0) + .map(([line]) => Number(line)) + .sort((a, b) => a - b); + } + if (fileCoverage) { + return fileCoverage.getUncoveredLines(); + } + return []; + }; + + const totals = { + lines: { covered: 0, total: 0 }, + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + }; + const fileSummaries = []; + + if (summary) { + const totalEntry = summary.total ?? {}; + ['lines', 'statements', 'branches', 'functions'].forEach((key) => { + if (totalEntry[key]) { + totals[key].covered = totalEntry[key].covered ?? 0; + totals[key].total = totalEntry[key].total ?? 
0; + } + }); + + Object.entries(summary) + .filter(([file]) => file !== 'total') + .forEach(([file, data]) => { + fileSummaries.push({ + file, + pct: data.lines?.pct ?? data.statements?.pct ?? 0, + lines: { + covered: data.lines?.covered ?? 0, + total: data.lines?.total ?? 0, + }, + }); + }); + } else if (coverage) { + Object.entries(coverage).forEach(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + totals.lines.total += lineTotal; + totals.lines.covered += lineCovered; + totals.statements.total += statementTotal; + totals.statements.covered += statementCovered; + totals.branches.total += branchTotal; + totals.branches.covered += branchCovered; + totals.functions.total += functionTotal; + totals.functions.covered += functionCovered; + + const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0); + + fileSummaries.push({ + file, + pct: pct(lineCovered || statementCovered, lineTotal || statementTotal), + lines: { + covered: lineCovered || statementCovered, + total: lineTotal || statementTotal, + }, + }); + }); + } + + const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00'); + + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('| Metric | Coverage | Covered / Total |'); + console.log('|--------|----------|-----------------|'); + console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`); + console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`); + console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`); + console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`); + + console.log(''); + console.log('
<details>'); + console.log('<summary>File coverage (lowest lines first)</summary>'); + console.log(''); + console.log('```'); + fileSummaries + .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total)) + .slice(0, 25) + .forEach(({ file, pct, lines }) => { + console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`); + }); + console.log('```'); + console.log('</details>
'); + + if (coverage) { + const pctValue = (covered, tot) => { + if (tot === 0) { + return '0'; + } + return ((covered / tot) * 100) + .toFixed(2) + .replace(/\.?0+$/, ''); + }; + + const formatLineRanges = (lines) => { + if (lines.length === 0) { + return ''; + } + const ranges = []; + let start = lines[0]; + let end = lines[0]; + + for (let i = 1; i < lines.length; i += 1) { + const current = lines[i]; + if (current === end + 1) { + end = current; + continue; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + start = current; + end = current; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + return ranges.join(','); + }; + + const tableTotals = { + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + lines: { covered: 0, total: 0 }, + }; + const tableRows = Object.entries(coverage) + .map(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + tableTotals.lines.total += lineTotal; + tableTotals.lines.covered += lineCovered; + tableTotals.statements.total += statementTotal; + tableTotals.statements.covered += statementCovered; + tableTotals.branches.total += branchTotal; + tableTotals.branches.covered += branchCovered; + tableTotals.functions.total += functionTotal; + tableTotals.functions.covered += functionCovered; + + const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits); + + const filePath = entry.path ?? file; + const relativePath = path.isAbsolute(filePath) + ? 
path.relative(process.cwd(), filePath) + : filePath; + + return { + file: relativePath || file, + statements: pctValue(statementCovered, statementTotal), + branches: pctValue(branchCovered, branchTotal), + functions: pctValue(functionCovered, functionTotal), + lines: pctValue(lineCovered, lineTotal), + uncovered: formatLineRanges(uncoveredLines), + }; + }) + .sort((a, b) => a.file.localeCompare(b.file)); + + const columns = [ + { key: 'file', header: 'File', align: 'left' }, + { key: 'statements', header: '% Stmts', align: 'right' }, + { key: 'branches', header: '% Branch', align: 'right' }, + { key: 'functions', header: '% Funcs', align: 'right' }, + { key: 'lines', header: '% Lines', align: 'right' }, + { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' }, + ]; + + const allFilesRow = { + file: 'All files', + statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total), + branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total), + functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total), + lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total), + uncovered: '', + }; + + const rowsForOutput = [allFilesRow, ...tableRows]; + const formatRow = (row) => `| ${columns + .map(({ key }) => String(row[key] ?? '')) + .join(' | ')} |`; + const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`; + const dividerRow = `| ${columns + .map(({ align }) => (align === 'right' ? '---:' : ':---')) + .join(' | ')} |`; + + console.log(''); + console.log('
<details>'); + console.log('<summary>Jest coverage table</summary>'); + console.log(''); + console.log(headerRow); + console.log(dividerRow); + rowsForOutput.forEach((row) => console.log(formatRow(row))); + console.log('</details>
'); + } + NODE + + - name: Upload Coverage Artifact + if: steps.coverage-summary.outputs.has_coverage == 'true' + uses: actions/upload-artifact@v4 + with: + name: web-coverage-report + path: web/coverage + retention-days: 30 + if-no-files-found: error diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index cb934d01b5..bdded1e73e 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", + "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention", "--loglevel", "INFO" ], diff --git a/api/.env.example b/api/.env.example index 43fe76bb11..b87d9c7b02 100644 --- a/api/.env.example +++ b/api/.env.example @@ -690,3 +690,8 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 # Maximum number of concurrent annotation import tasks per tenant ANNOTATION_IMPORT_MAX_CONCURRENT=5 + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/api/README.md b/api/README.md index 2dab2ec6e6..794b05d3af 100644 --- a/api/README.md +++ b/api/README.md @@ -84,7 +84,7 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor +uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index e16ca52f46..43dddbd011 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -218,7 +218,7 @@ class PluginConfig(BaseSettings): PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", - default=300.0, + default=600.0, ) INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") @@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings): ) +class SandboxExpiredRecordsCleanConfig(BaseSettings): + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field( + description="Graceful period in days for sandbox records clean after subscription expiration", + default=21, + ) + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field( + description="Maximum number of records to process in each batch", + default=1000, + ) + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field( + 
description="Retention days for sandbox expired workflow_run records and message records", + default=30, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1295,6 +1310,7 @@ class FeatureConfig( PositionConfig, RagEtlConfig, RepositoryConfig, + SandboxExpiredRecordsCleanConfig, SecurityConfig, TenantIsolatedTaskQueueConfig, ToolConfig, diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ea21c4480d..8ceb896d4f 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 5901eca915..a6e5b2822a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -40,7 +40,7 @@ from .. import console_ns logger = logging.getLogger(__name__) -class CompletionMessagePayload(BaseModel): +class CompletionMessageExplorePayload(BaseModel): inputs: dict[str, Any] query: str = "" files: list[dict[str, Any]] | None = None @@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel): raise ValueError("must be a valid UUID") from exc -register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) +register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload) # define completion api for user @@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): - @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) + @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 92da591ab4..51995b8b8a 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,5 +1,4 @@ from typing import Any -from uuid import UUID from flask import request from flask_restx import marshal_with @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account from models.model import AppMode @@ -24,7 +24,7 @@ from .. 
import console_ns class ConversationListQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) pinned: bool | None = None diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3c95779475..e42db10ba6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,8 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -18,6 +19,15 @@ from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService + +class InstalledAppCreatePayload(BaseModel): + app_id: str + + +class InstalledAppUpdatePayload(BaseModel): + is_pinned: bool | None = None + + logger = logging.getLogger(__name__) @@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") - args = parser.parse_args() + payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() if recommended_app is None: - raise NotFound("App not found") + raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == payload.app_id).first() if app is None: - raise NotFound("App not found") + raise NotFound("App entity not found") if not app.is_public: raise Forbidden("You can't install a non-public app") installed_app = ( db.session.query(InstalledApp) - .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource): recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args["app_id"], + app_id=payload.app_id, tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, @@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) - args = parser.parse_args() + payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False - if "is_pinned" in args: - installed_app.is_pinned = args["is_pinned"] + if payload.is_pinned is not None: + installed_app.is_pinned = payload.is_pinned commit_args = True if commit_args: diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..e9fbb515e4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,40 @@ +from typing import 
Literal + from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBindingPayload(BaseModel): + tag_ids: list[str] = Field(description="Tag IDs to bind") + target_id: str = Field(description="Target ID to bind tags to") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingRemovePayload(BaseModel): + tag_id: str = Field(description="Tag ID to remove") + target_id: str = Field(description="Target ID to unbind tag from") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +52,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +62,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, 
required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..a2fc45c29c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( setup_required, ) from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource): configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None - # Create provider + # Create provider in transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( @@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource): configuration=configuration, authentication=authentication, ) - return jsonable_encoder(result) + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(tenant_id) + + return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @setup_required @@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource): authentication = 
MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - # Step 1: Validate server URL change if needed (includes URL format validation and network operation) - validation_result = None + # Step 1: Get provider data for URL validation (short-lived session, no network I/O) + validation_data = None with Session(db.engine) as session: service = MCPToolManageService(session=session) - validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + validation_data = service.get_provider_for_url_validation( + tenant_id=current_tenant_id, provider_id=args["provider_id"] ) - # No need to check for errors here, exceptions will be raised directly + # Step 2: Perform URL validation with network I/O OUTSIDE of any database session + # This prevents holding database locks during potentially slow network operations + validation_result = MCPToolManageService.validate_server_url_standalone( + tenant_id=current_tenant_id, + new_server_url=args["server_url"], + validation_data=validation_data, + ) - # Step 2: Perform database update in a transaction + # Step 3: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource): authentication=authentication, validation_result=validation_result, ) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} @console_ns.expect(parser_mcp_delete) @setup_required @@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} parser_auth = ( diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 7692aeed23..4f91f40c55 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: RetrievalModel | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..15828cc208 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -20,6 +21,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error 
import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +31,25 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models + + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +109,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +124,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e8a4698375..a97d745471 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,9 +1,11 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(default="", description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + 
default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +88,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +171,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, 
AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 93f2742599..307af3747c 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,3 +1,4 @@ +import json from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal @@ -120,7 +121,7 @@ class VariableEntity(BaseModel): allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) + json_schema: str | None = Field(default=None) @field_validator("description", mode="before") @classmethod @@ -134,11 +135,17 @@ class VariableEntity(BaseModel): @field_validator("json_schema") @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: + def validate_json_schema(cls, schema: str | None) -> str | None: if schema is None: return None + try: - Draft7Validator.check_schema(schema) + json_schema = json.loads(schema) + except json.JSONDecodeError: + raise ValueError(f"invalid json_schema value {schema}") + + try: + Draft7Validator.check_schema(json_schema) except SchemaError as e: raise ValueError(f"Invalid JSON schema: {e.message}") return schema diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 1b0474142e..02d58a07d1 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,3 +1,4 @@ +import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Union, final @@ -175,6 +176,13 @@ class BaseAppGenerator: value = True elif value == 0: value = False + case VariableEntityType.JSON_OBJECT: + if not isinstance(value, str): + raise ValueError(f"{variable_entity.variable} in input form must be a string") + try: + json.loads(value) + except json.JSONDecodeError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object") case _: raise AssertionError("this statement should be unreachable.") diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5c169f4db1..5bb93fa44a 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -342,9 +342,11 @@ class 
EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): + event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, + event_type=event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2e6f92efa5..0e7f300cee 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -54,6 +54,20 @@ class MessageCycleManager: ): self._application_generate_entity = application_generate_entity self._task_state = task_state + self._message_has_file: set[str] = set() + + def get_message_event_type(self, message_id: str) -> StreamEvent: + if message_id in self._message_has_file: + return StreamEvent.MESSAGE_FILE + + with Session(db.engine, expire_on_commit=False) as session: + has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + + if has_file: + self._message_has_file.add(message_id) + return StreamEvent.MESSAGE_FILE + + return StreamEvent.MESSAGE def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ @@ -214,7 +228,11 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: list[str] | None = None + self, + answer: str, + message_id: str, + from_variable_selector: list[str] | None = None, + event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. @@ -222,16 +240,12 @@ class MessageCycleManager: :param message_id: message id :return: """ - with Session(db.engine, expire_on_commit=False) as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id)) - event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE - return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, - event=event_type, + event=event_type or StreamEvent.MESSAGE, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 92787b39dd..aef1afb235 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls( """ Build a list of URLs to try for Protected Resource Metadata discovery. - Per SEP-985, supports fallback when discovery fails at one URL. + Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL. + Priority order: + 1. URL from WWW-Authenticate header (if provided) + 2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp + 3. 
Well-known URI at root: https://example.com/.well-known/oauth-protected-resource """ urls = [] @@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls( # Fallback: construct from server URL parsed = urlparse(server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - if fallback_url not in urls: - urls.append(fallback_url) + path = parsed.path.rstrip("/") + + # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp) + if path: + path_url = f"{base_url}/.well-known/oauth-protected-resource{path}" + if path_url not in urls: + urls.append(path_url) + + # Priority 3: At root (e.g., /.well-known/oauth-protected-resource) + root_url = f"{base_url}/.well-known/oauth-protected-resource" + if root_url not in urls: + urls.append(root_url) return urls @@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. - Per RFC 8414 section 3: - - If issuer has no path: https://example.com/.well-known/oauth-authorization-server - - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - - Example: - - issuer: https://example.com/oauth - - metadata: https://example.com/.well-known/oauth-authorization-server/oauth + Per RFC 8414 section 3.1 and section 5, try all possible endpoints: + - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1 + - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1 + - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration + - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server + - OpenID Connect at root: https://example.com/.well-known/openid-configuration """ urls = [] base_url = auth_server_url or server_url parsed = urlparse(base_url) base = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.rstrip("/") # Remove trailing slash + path = parsed.path.rstrip("/") + # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26) + urls.append(f"{base}/.well-known/oauth-authorization-server") - # Try OpenID Connect discovery first (more common) - urls.append(urljoin(base + "/", ".well-known/openid-configuration")) + # OpenID Connect Discovery at root + urls.append(f"{base}/.well-known/openid-configuration") - # OAuth 2.0 Authorization Server Metadata (RFC 8414) - # Include the path component if present in the issuer URL if path: - urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) - else: - urls.append(urljoin(base, ".well-known/oauth-authorization-server")) + # OpenID Connect Discovery with path insertion + urls.append(f"{base}/.well-known/openid-configuration{path}") + + # OpenID Connect Discovery path appending + urls.append(f"{base}{path}/.well-known/openid-configuration") + + # OAuth 2.0 Authorization Server Metadata with path insertion + urls.append(f"{base}/.well-known/oauth-authorization-server{path}") return urls diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index b0e0dab9be..2b0645b558 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -59,7 +59,7 @@ class MCPClient: try: logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) self.connect_server(sse_client, "sse") - except MCPConnectionError: + except (MCPConnectionError, ValueError): 
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index a6caa7eb1e..b9d2c55210 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model - Model provider display - ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md). + Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. - Selectable model list display - ![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png) - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - In addition, this list also returns configurable parameter information and rules for LLM, as shown below: - - ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) - - These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule). + In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - Provider/model credential authentication - ![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png) - - ![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png) - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO. + The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. ## Structure -![](./docs/en_US/images/index/image-20231210165243632.png) - Model Runtime is divided into three layers: - The outermost layer is the factory method @@ -60,9 +46,6 @@ Model Runtime is divided into three layers: It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). 
-## Next Steps +## Documentation -- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) -- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel) -- View YAML configuration rules: [Link](./docs/en_US/schema.md) -- Implement interface methods: [Link](./docs/en_US/interfaces.md) +For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index dfe614347a..0a8b56b3fe 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -18,34 +18,20 @@ - 模型供应商展示 - ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) - -​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 + 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - 可选择的模型列表展示 - ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) + 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 -​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - -​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: - -​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) - -​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 + 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - 供应商/模型凭据鉴权 - ![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png) - -![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png) - -​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 + 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 ## 结构 -![](./docs/zh_Hans/images/index/image-20231210165243632.png) - Model Runtime 分三层: - 最外层为工厂方法 @@ -59,8 +45,7 @@ Model Runtime 分三层: 对于供应商/模型凭据,有两种情况 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - ![Alt text](docs/zh_Hans/images/index/image.png) + - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 @@ -74,20 +59,6 @@ Model Runtime 分三层: - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 -## 下一步 +## 文档 -### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) - -当添加后,这里将会出现一个新的供应商 - -![Alt text](docs/zh_Hans/images/index/image-1.png) - -### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B) - -当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。 - -![Alt text](docs/zh_Hans/images/index/image-2.png) - -### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) - -你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 +有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index a1c84bd5d9..7bb2749afa 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -39,7 +39,7 @@ from core.trigger.errors import 
( plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( float | httpx.Timeout | None, - getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), + getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0), ) plugin_daemon_request_timeout: httpx.Timeout | None if _plugin_daemon_timeout_config is None: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 044b118635..f67f613e9d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor): def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: @@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - # Use r_id as key for external images since target_part is undefined - image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor): used_at=naive_utc_now(), ) db.session.add(upload_file) - # Use target_part as key for internal images - image_map[rel.target_part] = ( - f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" - ) + image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" db.session.commit() return image_map diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 801d2a2a52..b65cb14d8e 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import codecs import re from typing import Any @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) - self._fixed_separator = fixed_separator + self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape") self._separators = separators or ["\n\n", "\n", "。", ". 
", " ", ""] def split_text(self, text: str) -> list[str]: @@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = re.split(r" +", text) else: splits = text.split(separator) - splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] + if self._keep_separator: + splits = [s + separator for s in splits[:-1]] + splits[-1:] else: splits = list(text) if separator == "\n": @@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = separator if self._keep_separator else "" + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index f0c84872fb..931c6113a7 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -86,6 +86,11 @@ class Executor: node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text + # Validate that API key is not empty after template conversion + if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): + raise AuthorizationConfigError( + "API key is required for authorization but was empty. Please provide a valid API key." + ) self.url = node_data.url self.method = node_data.method diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 38effa79f7..36fc5078c5 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,4 @@ +import json from typing import Any from jsonschema import Draft7Validator, ValidationError @@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]): if value is None and variable.required: raise ValueError(f"{key} is required in input form") - if not isinstance(value, dict): - raise ValueError(f"{key} must be a JSON object") - schema = variable.json_schema if not schema: continue + if not value: + continue + try: - Draft7Validator(schema).validate(value) + json_schema = json.loads(schema) + except json.JSONDecodeError as e: + raise ValueError(f"{schema} must be a valid JSON object") + + try: + json_value = json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"{value} must be a valid JSON object") + + try: + Draft7Validator(json_schema).validate(json_value) except ValidationError as e: raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") - node_inputs[key] = value + node_inputs[key] = json_value diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6313085e64..5a69eb15ac 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + 
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} + +elif [[ "${MODE}" == "job" ]]; then + # Job mode: Run a one-time Flask command and exit + # Pass Flask command and arguments via container args + # Example K8s usage: + # args: + # - create-tenant + # - --email + # - admin@example.com + # + # Example Docker usage: + # docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com + + if [[ $# -eq 0 ]]; then + echo "Error: No command specified for job mode." + echo "" + echo "Usage examples:" + echo " Kubernetes:" + echo " args: [create-tenant, --email, admin@example.com]" + echo "" + echo " Docker:" + echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com" + echo "" + echo "Available commands:" + echo " create-tenant, reset-password, reset-email, upgrade-db," + echo " vdb-migrate, install-plugins, and more..." + echo "" + echo "Run 'flask --help' to see all available commands." + exit 1 + fi + + echo "Running Flask job command: flask $*" + + # Temporarily disable exit on error to capture exit code + set +e + flask "$@" + JOB_EXIT_CODE=$? + set -e + + if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then + echo "Job completed successfully." + else + echo "Job failed with exit code ${JOB_EXIT_CODE}." 
+ fi + + exit ${JOB_EXIT_CODE} + else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index a084844d72..83c5c2d12f 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage): if not self.exists(path): raise FileNotFoundError("Path not found") - all_files = self.op.scan(path=path) + # Use the new OpenDAL 0.46.0+ API with recursive listing + lister = self.op.list(path, recursive=True) if files and directories: logger.debug("files and directories on %s scanned", path) - return [f.path for f in all_files] + return [entry.path for entry in lister] if files: logger.debug("files on %s scanned", path) - return [f.path for f in all_files if not f.path.endswith("/")] + return [entry.path for entry in lister if not entry.metadata.is_dir] elif directories: logger.debug("directories on %s scanned", path) - return [f.path for f in all_files if f.path.endswith("/")] + return [entry.path for entry in lister if entry.metadata.is_dir] else: raise ValueError("At least one of files or directories must be True") diff --git a/api/libs/helper.py b/api/libs/helper.py index 4a7afe0bda..74e1808e49 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context @@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str: raise ValueError(error) +def normalize_uuid(value: str | UUID) -> str: + if not value: + return "" + + try: + return uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] + + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r"^[a-zA-Z0-9_]+$", value): diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..3d7cb6cc8d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,12 @@ +import logging import os +from collections.abc import Sequence from typing import Literal import httpx +from pydantic import TypeAdapter from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan @@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client from libs.helper import RateLimiter from models import Account, TenantAccountJoin, TenantAccountRole +logger = logging.getLogger(__name__) + + +class SubscriptionPlan(TypedDict): + """Tenant subscription plan information.""" + + plan: str + expiration_date: int + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") @@ -239,3 +252,39 @@ class BillingService: def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str): payload = {"account_id": account_id, "click_id": click_id} return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload) + + @classmethod + def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Bulk 
fetch billing subscription plan via billing API. + Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) + Returns: + Mapping of tenant_id -> {plan: str, expiration_date: int} + """ + results: dict[str, SubscriptionPlan] = {} + subscription_adapter = TypeAdapter(SubscriptionPlan) + + chunk_size = 200 + for i in range(0, len(tenant_ids), chunk_size): + chunk = tenant_ids[i : i + chunk_size] + try: + resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) + data = resp.get("data", {}) + + for tenant_id, plan in data.items(): + subscription_plan = subscription_adapter.validate_python(plan) + results[tenant_id] = subscription_plan + except Exception: + logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + continue + + return results + + @classmethod + def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]: + resp = cls._send_request("GET", "/subscription/cleanup/whitelist") + data = resp.get("data", []) + tenant_whitelist = [] + for item in data: + tenant_whitelist.append(item["tenant_id"]) + return tenant_whitelist diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index a97ccab914..cbb0efcc2a 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None yaml_content: str | None = None diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index d641fe0315..252be77b27 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel): return self.needs_validation and self.validation_passed and self.reconnect_result is not None +class ProviderUrlValidationData(BaseModel): + """Data required for URL validation, extracted from database to perform network operations outside of session""" + + current_server_url_hash: str + headers: dict[str, str] + timeout: float | None + sse_read_timeout: float | None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -166,9 +174,6 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -192,7 +197,7 @@ class MCPToolManageService: Update an MCP provider. Args: - validation_result: Pre-validation result from validate_server_url_change. + validation_result: Pre-validation result from validate_server_url_standalone. 
If provided and contains reconnect_result, it will be used instead of performing network operations. """ @@ -251,8 +256,6 @@ class MCPToolManageService: # Flush changes to database self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -261,9 +264,6 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: @@ -546,30 +546,39 @@ class MCPToolManageService: ) return self.execute_auth_actions(auth_result) - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: - """Attempt to reconnect to MCP provider with new server URL.""" + def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData: + """ + Get provider data required for URL validation. + This method performs database read and should be called within a session. + + Returns: + ProviderUrlValidationData: Data needed for standalone URL validation + """ + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = provider.to_entity() - headers = provider_entity.headers + return ProviderUrlValidationData( + current_server_url_hash=provider.server_url_hash, + headers=provider_entity.headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ) - try: - tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) - return ReconnectResult( - authed=True, - tools=json.dumps([tool.model_dump() for tool in tools]), - encrypted_credentials=EMPTY_CREDENTIALS_JSON, - ) - except MCPAuthError: - return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) - except MCPError as e: - raise ValueError(f"Failed to re-connect MCP server: {e}") from e - - def validate_server_url_change( - self, *, tenant_id: str, provider_id: str, new_server_url: str + @staticmethod + def validate_server_url_standalone( + *, + tenant_id: str, + new_server_url: str, + validation_data: ProviderUrlValidationData, ) -> ServerUrlValidationResult: """ Validate server URL change by attempting to connect to the new server. - This method should be called BEFORE update_provider to perform network operations - outside of the database transaction. + This method performs network operations and MUST be called OUTSIDE of any database session + to avoid holding locks during network I/O. 
+ + Args: + tenant_id: Tenant ID for encryption + new_server_url: The new server URL to validate + validation_data: Provider data obtained from get_provider_for_url_validation Returns: ServerUrlValidationResult: Validation result with connection status and tools if successful @@ -579,25 +588,30 @@ class MCPToolManageService: return ServerUrlValidationResult(needs_validation=False) # Validate URL format - if not self._is_valid_url(new_server_url): + parsed = urlparse(new_server_url) + if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]: raise ValueError("Server URL is not valid.") # Always encrypt and hash the URL encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() - # Get current provider - provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - # Check if URL is actually different - if new_server_url_hash == provider.server_url_hash: + if new_server_url_hash == validation_data.current_server_url_hash: # URL hasn't changed, but still return the encrypted data return ServerUrlValidationResult( - needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + needs_validation=False, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, ) - # Perform validation by attempting to connect - reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + # Perform network validation - this is the expensive operation that should be outside session + reconnect_result = MCPToolManageService._reconnect_with_url( + server_url=new_server_url, + headers=validation_data.headers, + timeout=validation_data.timeout, + sse_read_timeout=validation_data.sse_read_timeout, + ) return ServerUrlValidationResult( needs_validation=True, validation_passed=True, @@ -606,6 +620,38 @@ class MCPToolManageService: server_url_hash=new_server_url_hash, ) + @staticmethod + def _reconnect_with_url( + *, + server_url: str, + headers: dict[str, str], + timeout: float | None, + sse_read_timeout: float | None, + ) -> ReconnectResult: + """ + Attempt to connect to MCP server with given URL. + This is a static method that performs network I/O without database access. 
+ """ + from core.mcp.mcp_client import MCPClient + + try: + with MCPClient( + server_url=server_url, + headers=headers, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) as mcp_client: + tools = mcp_client.list_tools() + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, + ) + except MCPAuthError: + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) + except MCPError as e: + raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4c1f38c3bb..5fc2597c92 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -2,7 +2,6 @@ import logging import time import click -import sqlalchemy as sa from celery import shared_task from sqlalchemy import select @@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment -from models.source import DataSourceOauthBinding +from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_id = data_source_info["notion_page_id"] page_type = data_source_info["type"] page_edited_time = data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .where( - sa.and_( - DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ) - .first() + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_datasource_credentials( + tenant_id=document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", ) - if not data_source_binding: - raise ValueError("Data source binding not found.") + + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + document_id, + document.tenant_id, + credential_id, + ) + document.indexing_status = "error" + document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
+ document.stopped_at = naive_utc_now() + db.session.commit() + db.session.close() + return loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, - notion_access_token=data_source_binding.access_token, + notion_access_token=credential.get("integration_secret"), tenant_id=document.tenant_id, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e75258a2a2..d814da8ec7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -6,6 +6,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.node_factory import DifyNodeFactory @@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) -def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): - """Test: In custom authentication mode, when the api_key is empty, no header should be set.""" +def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): + """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" from core.workflow.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) + from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): ssl_verify=True, ) - # Create executor - executor = Executor( - node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool - ) - - # Get assembled headers - headers = executor._assembling_headers() - - # When api_key is empty, the custom header should NOT be set - assert "X-Custom-Auth" not in headers + # Create executor should raise AuthorizationConfigError + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), + variable_pool=variable_pool, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) @@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_with_empty_api_key(setup_http_mock): """ - Test that custom authorization doesn't set header when api_key is empty. - This test verifies the fix for issue #23554. + Test that custom authorization raises error when api_key is empty. + This test verifies the fix for issue #21830. 
""" + node = init_http_node( config={ "id": "1", @@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock): ) result = node._run() - assert result.process_data is not None - data = result.process_data.get("request", "") - - # Custom header should NOT be set when api_key is empty - assert "X-Custom-Auth:" not in data + # Should fail with AuthorizationConfigError + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "API key is required" in result.error + assert result.error_type == "AuthorizationConfigError" @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 0871467a05..2ff71ea6ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import TypeAdapter, ValidationError +from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -298,7 +300,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -364,7 +366,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none"} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -428,21 +430,10 @@ class TestApiToolManageService: labels = ["test"] # Act & Assert: Try to create provider with invalid schema type - with pytest.raises(ValueError) as exc_info: - ApiToolManageService.create_api_tool_provider( - user_id=account.id, - tenant_id=tenant.id, - provider_name=provider_name, - icon=icon, - credentials=credentials, - schema_type=schema_type, - schema=schema, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - labels=labels, - ) + with pytest.raises(ValidationError) as exc_info: + TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) - assert "invalid schema type" in str(exc_info.value) + assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies @@ -464,7 +455,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {} # Missing auth_type - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -507,7 +498,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔑"} 
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 8c190762cf..6cae83ac37 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1308,18 +1308,17 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1337,8 +1336,12 @@ class TestMCPToolManageService: assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - provider_entity = mcp_provider.to_entity() - mock_mcp_client.assert_called_once() + mock_mcp_client.assert_called_once_with( + server_url="https://example.com/mcp", + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, + ) def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1361,19 +1364,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1404,18 +1406,17 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify 
proper error handling - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - service._reconnect_provider( + MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py new file mode 100644 index 0000000000..40f58c9ddf --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -0,0 +1,420 @@ +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueuePingEvent, +) +from core.app.entities.task_entities import ( + EasyUITaskState, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, + PingStreamResponse, + StreamEvent, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AppGeneratorTTSPublisher +from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse: + """Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=ChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + # minimal model_conf for LLMResult init + entity.model_conf = SimpleNamespace( + model="test-model", + provider_model_bundle=SimpleNamespace(model_type_instance=Mock()), + credentials={}, + ) + return entity + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock(spec=AppQueueManager) + return manager + + @pytest.fixture + def mock_message_cycle_manager(self): + """Create a mock message cycle manager.""" + manager = Mock() + manager.get_message_event_type.return_value = StreamEvent.MESSAGE + manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse) + manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse) + manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse) + manager.handle_retriever_resources = Mock() + manager.handle_annotation_reply.return_value = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create 
a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_task_state(self): + """Create a mock task state.""" + task_state = Mock(spec=EasyUITaskState) + + # Create LLM result mock + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.prompt_messages = [] + llm_result.message = Mock() + llm_result.message.content = "" + + task_state.llm_result = llm_result + task_state.answer = "" + + return task_state + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_queue_manager, + mock_conversation, + mock_message, + mock_message_cycle_manager, + mock_task_state, + ): + """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies.""" + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state + ): + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + stream=True, + ) + pipeline._message_cycle_manager = mock_message_cycle_manager + pipeline._task_state = mock_task_state + return pipeline + + def test_get_message_event_type_called_once_when_first_llm_chunk_arrives( + self, pipeline, mock_message_cycle_manager + ): + """Expect get_message_event_type to be called when processing the first LLM chunk event.""" + # Setup a minimal LLM chunk event + chunk = Mock() + chunk.delta.message.content = "hi" + chunk.prompt_messages = [] + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id") + + def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with text content.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Hello, world!" + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello, world!" 
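The pattern these tests pin down, resolving the stream event type once per message and reusing it for every chunk, can be summarised in a short caller-side sketch. It assumes only the `get_message_event_type` and `message_to_stream_response` signatures introduced in this change; the `stream_chunks` helper and its `chunks` iterable are hypothetical.

```python
# Illustrative caller-side sketch; `stream_chunks` and `chunks` are hypothetical.
def stream_chunks(manager, message_id, chunks):
    # Resolve the event type once; MessageCycleManager caches the MessageFile lookup.
    event_type = manager.get_message_event_type(message_id=message_id)
    for delta_text in chunks:
        # Reuse the precomputed event type for every chunk instead of querying per chunk.
        yield manager.message_to_stream_response(
            answer=delta_text,
            message_id=message_id,
            event_type=event_type,
        )
```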
+ + def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with list content.""" + # Setup + text_content = Mock(spec=TextPromptMessageContent) + text_content.data = "Hello" + + chunk = Mock() + chunk.delta.message.content = [text_content, " world!"] + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello world!" + + def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of agent message events.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Agent response" + + agent_message_event = Mock(spec=QueueAgentMessageEvent) + agent_message_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = agent_message_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Ensure method under assertion is a mock to track calls + pipeline._agent_message_to_stream_response = Mock(return_value=Mock()) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + # Agent messages should use _agent_message_to_stream_response + pipeline._agent_message_to_stream_response.assert_called_once_with( + answer="Agent response", message_id="test-message-id" + ) + + def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of message end events.""" + # Setup + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.message = Mock() + llm_result.message.content = "Final response" + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = llm_result + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert mock_task_state.llm_result == llm_result + pipeline._save_message.assert_called_once() + pipeline._message_end_to_stream_response.assert_called_once() + + def test_error_event(self, pipeline): + """Test handling of error events.""" + # Setup + error_event = Mock(spec=QueueErrorEvent) + error_event.error = Exception("Test error") + + mock_queue_message = Mock() + mock_queue_message.event = error_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.handle_error = Mock(return_value=Exception("Test error")) + 
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.handle_error.assert_called_once() + pipeline.error_to_stream_response.assert_called_once() + + def test_ping_event(self, pipeline): + """Test handling of ping events.""" + # Setup + ping_event = Mock(spec=QueuePingEvent) + + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.ping_stream_response.assert_called_once() + + def test_file_event(self, pipeline, mock_message_cycle_manager): + """Test handling of file events.""" + # Setup + file_event = Mock(spec=QueueMessageFileEvent) + file_event.message_file_id = "file-id" + + mock_queue_message = Mock() + mock_queue_message.event = file_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + file_response = Mock(spec=MessageFileStreamResponse) + mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert responses[0] == file_response + mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event) + + def test_publisher_is_called_with_messages(self, pipeline): + """Test that publisher publishes messages when provided.""" + # Setup + publisher = Mock(spec=AppGeneratorTTSPublisher) + + ping_event = Mock(spec=QueuePingEvent) + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + list(pipeline._process_stream_response(publisher=publisher, trace_manager=None)) + + # Assert + # Called once with message and once with None at the end + assert publisher.publish.call_count == 2 + publisher.publish.assert_any_call(mock_queue_message) + publisher.publish.assert_any_call(None) + + def test_trace_manager_passed_to_save_message(self, pipeline): + """Test that trace manager is passed to _save_message.""" + # Setup + trace_manager = Mock(spec=TraceQueueManager) + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = None + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager)) + + # Assert + pipeline._save_message.assert_called_once_with(session=ANY, 
trace_manager=trace_manager) + + def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling multiple events in sequence.""" + # Setup + chunk1 = Mock() + chunk1.delta.message.content = "Hello" + chunk1.prompt_messages = [] + + chunk2 = Mock() + chunk2.delta.message.content = " world!" + chunk2.prompt_messages = [] + + llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event1.chunk = chunk1 + + ping_event = Mock(spec=QueuePingEvent) + + llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event2.chunk = chunk2 + + mock_queue_messages = [ + Mock(event=llm_chunk_event1), + Mock(event=ping_event), + Mock(event=llm_chunk_event2), + ] + pipeline.queue_manager.listen.return_value = mock_queue_messages + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 3 + assert mock_task_state.llm_result.message.content == "Hello world!" + + # Verify calls to message_to_stream_response + assert mock_message_cycle_manager.message_to_stream_response.call_count == 2 + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py new file mode 100644 index 0000000000..5ef7f0d7f4 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -0,0 +1,166 @@ +"""Unit tests for the message cycle manager optimization.""" + +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest +from flask import current_app + +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager + + +class TestMessageCycleManagerOptimization: + """Test cases for the message cycle manager optimization that prevents N+1 queries.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock() + entity.task_id = "test-task-id" + return entity + + @pytest.fixture + def message_cycle_manager(self, mock_application_generate_entity): + """Create a message cycle manager instance.""" + task_state = Mock() + return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) + + def test_get_message_event_type_with_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + 
mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_get_message_event_type_without_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message has no files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and no message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): + """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute: compute event type once, then pass to message_to_stream_response + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=event_type + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): + """Test that message_to_stream_response skips database query when event_type is provided.""" + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + # Execute with event_type provided + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE + # Should not query database when event_type is provided + mock_session_class.assert_not_called() + + def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): + """Test message_to_stream_response with from_variable_selector parameter.""" + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", + message_id="test-message-id", + 
from_variable_selector=["var1", "var2"], + event_type=StreamEvent.MESSAGE, + ) + + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.from_variable_selector == ["var1", "var2"] + assert result.event == StreamEvent.MESSAGE + + def test_optimization_usage_example(self, message_cycle_manager): + """Test the optimization pattern that should be used by callers.""" + # Step 1: Get event type once (this queries database) + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None # No files + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + + # Should query database once + mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + assert event_type == StreamEvent.MESSAGE + + # Step 2: Use event_type for multiple calls (no additional queries) + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = Mock() + + chunk1_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 1", message_id="test-message-id", event_type=event_type + ) + + chunk2_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 2", message_id="test-message-id", event_type=event_type + ) + + # Should not query database again + mock_session_class.assert_not_called() + + assert chunk1_response.event == StreamEvent.MESSAGE + assert chunk2_response.event == StreamEvent.MESSAGE + assert chunk1_response.answer == "Chunk 1" + assert chunk2_response.answer == "Chunk 2" diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index fd0b0e2e44..3203aab8c3 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch): # DB interactions should be recorded assert len(db_stub.session.added) == 2 assert db_stub.session.committed is True + + +def test_extract_images_from_docx_uses_internal_files_url(): + """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access.""" + # Test the URL generation logic directly + from configs import dify_config + + # Mock the configuration values + original_files_url = getattr(dify_config, "FILES_URL", None) + original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None) + + try: + # Set both URLs - INTERNAL should take precedence + dify_config.FILES_URL = "http://external.example.com" + dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001" + + # Test the URL generation logic (same as in word_extractor.py) + upload_file_id = "test_file_id" + + # This is the pattern we fixed in the word extractor + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + generated_url = f"{base_url}/files/{upload_file_id}/file-preview" + + # Verify that INTERNAL_FILES_URL is used instead of FILES_URL + assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: 
{generated_url}" + assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}" + + finally: + # Restore original values + if original_files_url is not None: + dify_config.FILES_URL = original_files_url + if original_internal_files_url is not None: + dify_config.INTERNAL_FILES_URL = original_internal_files_url diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 7d246ac3cc..943a9e5712 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter: # Verify no empty chunks assert all(len(chunk) > 0 for chunk in result) + def test_double_slash_n(self): + data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2." + separator = "\\n\\n---\\n\\n" + splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator) + chunks = splitter.split_text(data) + assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index f040a92b6f..27df938102 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,3 +1,5 @@ +import pytest + from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeData, ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -348,3 +351,127 @@ def test_init_params(): executor = create_executor("key1:value1\n\nkey2:value2\n\n") executor._init_params() assert executor.params == [("key1", "value1"), ("key2", "value2")] + + +def test_empty_api_key_raises_error_bearer(): + """Test that empty API key raises AuthorizationConfigError for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": ""}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_basic(): + """Test that empty API key raises AuthorizationConfigError for basic auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "basic", "api_key": ""}, + ), + ) + timeout = 
HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_custom(): + """Test that empty API key raises AuthorizationConfigError for custom auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_whitespace_only_api_key_raises_error(): + """Test that whitespace-only API key raises AuthorizationConfigError.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": " "}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_valid_api_key_works(): + """Test that valid API key works correctly for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": "valid-api-key-123"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + executor = Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + # Should not raise an error + headers = executor._assembling_headers() + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer valid-api-key-123" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 83799c9508..539e72edb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -1,3 +1,4 @@ +import json import time import pytest @@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables): def test_json_object_valid_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age"], + } + ) variables = [ VariableEntity( @@ -65,7 +68,7 @@ def test_json_object_valid_schema(): ) ] - user_inputs = {"profile": {"age": 20, "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})} node = make_start_node(user_inputs, variables) result = node._run() @@ -74,12 +77,23 @@ def test_json_object_valid_schema(): def 
test_json_object_invalid_json_string(): + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, required=True, + json_schema=schema, ) ] @@ -88,38 +102,21 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="profile must be a JSON object"): - node._run() - - -@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"]) -def test_json_object_valid_json_but_not_object(value): - variables = [ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ] - - user_inputs = {"profile": value} - - node = make_start_node(user_inputs, variables) - - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'): node._run() def test_json_object_does_not_match_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema(): ] # age is a string, which violates the schema (expects number) - user_inputs = {"profile": {"age": "twenty", "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})} node = make_start_node(user_inputs, variables) @@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema(): def test_json_object_missing_required_schema_field(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field(): ] # Missing required field "name" - user_inputs = {"profile": {"age": 20}} + user_inputs = {"profile": json.dumps({"age": 20})} node = make_start_node(user_inputs, variables) @@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided(): variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, - required=False, + required=True, ) ] @@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided(): node = make_start_node(user_inputs, variables) # Current implementation raises a validation error even when the variable is optional - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match="profile is required in input form"): node._run() diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 915aee3fa7..f50f744a75 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases: assert "Only team owner or team admin can perform this action" in str(exc_info.value) +class TestBillingServiceSubscriptionOperations: + """Unit tests 
for subscription operations in BillingService. + + Tests cover: + - Bulk plan retrieval with chunking + - Expired subscription cleanup whitelist retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_plan_bulk_with_empty_list(self, mock_send_request): + """Test bulk plan retrieval with empty tenant list.""" + # Arrange + tenant_ids = [] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + mock_send_request.assert_not_called() + + def test_get_plan_bulk_with_chunking(self, mock_send_request): + """Test bulk plan retrieval with more than 200 tenants (chunking logic).""" + # Arrange - 250 tenants to test chunking (chunk_size = 200) + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk: tenants 0-199 + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk: tenants 200-249 + second_chunk_response = { + "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)} + } + + mock_send_request.side_effect = [first_chunk_response, second_chunk_response] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 250 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert result["tenant-200"]["plan"] == "professional" + assert result["tenant-249"]["plan"] == "professional" + assert mock_send_request.call_count == 2 + + # Verify first chunk call + first_call = mock_send_request.call_args_list[0] + assert first_call[0][0] == "POST" + assert first_call[0][1] == "/subscription/plan/batch" + assert len(first_call[1]["json"]["tenant_ids"]) == 200 + + # Verify second chunk call + second_call = mock_send_request.call_args_list[1] + assert len(second_call[1]["json"]["tenant_ids"]) == 50 + + def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request): + """Test bulk plan retrieval when one batch fails but others succeed.""" + # Arrange - 250 tenants, second batch will fail + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk succeeds + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk fails - need to create a mock that raises when called + def side_effect_func(*args, **kwargs): + if mock_send_request.call_count == 1: + return first_chunk_response + else: + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only have data from first batch + assert len(result) == 200 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert "tenant-200" not in result + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request): + """Test bulk plan retrieval when all batches fail.""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # All chunks fail + def side_effect_func(*args, **kwargs): + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should return empty dict + assert result == {} + assert mock_send_request.call_count == 2 + + def 
test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request): + """Test bulk plan retrieval with exactly 200 tenants (boundary condition).""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(200)] + mock_send_request.return_value = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 200 + assert mock_send_request.call_count == 1 + + def test_get_plan_bulk_with_empty_data_response(self, mock_send_request): + """Test bulk plan retrieval with empty data in response.""" + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + mock_send_request.return_value = {"data": {}} + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): + """Test successful retrieval of expired subscription cleanup whitelist.""" + # Arrange + api_response = [ + { + "created_at": "2025-10-16T01:56:17", + "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", + "contact": "example@dify.ai", + "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5", + "expired_at": "2026-01-01T01:56:17", + "updated_at": "2025-10-16T01:56:17", + }, + { + "created_at": "2025-10-16T02:00:00", + "tenant_id": "tenant-2", + "contact": "test@example.com", + "id": "whitelist-id-2", + "expired_at": "2026-02-01T00:00:00", + "updated_at": "2025-10-16T02:00:00", + }, + { + "created_at": "2025-10-16T03:00:00", + "tenant_id": "tenant-3", + "contact": "another@example.com", + "id": "whitelist-id-3", + "expired_at": "2026-03-01T00:00:00", + "updated_at": "2025-10-16T03:00:00", + }, + ] + mock_send_request.return_value = {"data": api_response} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert - should return only tenant_ids + assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"] + assert len(result) == 3 + assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6" + assert result[1] == "tenant-2" + assert result[2] == "tenant-3" + mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist") + + def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request): + """Test retrieval of empty cleanup whitelist.""" + # Arrange + mock_send_request.return_value = {"data": []} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert + assert result == [] + assert len(result) == 0 + + class TestBillingServiceIntegrationScenarios: """Integration-style tests simulating real-world usage scenarios. diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py new file mode 100644 index 0000000000..374abe0368 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -0,0 +1,520 @@ +""" +Unit tests for document indexing sync task. 
+ +This module tests the document indexing sync task functionality including: +- Syncing Notion documents when updated +- Validating document and data source existence +- Credential validation and retrieval +- Cleaning old segments before re-indexing +- Error handling and edge cases +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_sync_task import document_indexing_sync_task + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_id(): + """Generate a unique document ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_workspace_id(): + """Generate a Notion workspace ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_page_id(): + """Generate a Notion page ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def credential_id(): + """Generate a credential ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id): + """Create a mock Document object with Notion data source.""" + doc = Mock(spec=Document) + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.data_source_type = "notion_import" + doc.indexing_status = "completed" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + doc.data_source_info_dict = { + "notion_workspace_id": notion_workspace_id, + "notion_page_id": notion_page_id, + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + "credential_id": credential_id, + } + return doc + + +@pytest.fixture +def mock_document_segments(document_id): + """Create mock DocumentSegment objects.""" + segments = [] + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = document_id + segment.index_node_id = f"node-{document_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.document_indexing_sync_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_datasource_provider_service(): + """Mock DatasourceProviderService.""" + with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class: + mock_service = MagicMock() + mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"} 
+ mock_service_class.return_value = mock_service + yield mock_service + + +@pytest.fixture +def mock_notion_extractor(): + """Mock NotionExtractor.""" + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time + mock_extractor_class.return_value = mock_extractor + yield mock_extractor + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner.run = Mock() + mock_runner_class.return_value = mock_runner + yield mock_runner + + +# ============================================================================ +# Tests for document_indexing_sync_task +# ============================================================================ + + +class TestDocumentIndexingSyncTask: + """Tests for the document_indexing_sync_task function.""" + + def test_document_not_found(self, mock_db_session, dataset_id, document_id): + """Test that task handles document not found gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_db_session.close.assert_called_once() + + def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_workspace_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_page_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when data_source_info is empty.""" + # Arrange + mock_document.data_source_info_dict = None + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_credential_not_found( + self, + mock_db_session, + mock_datasource_provider_service, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credentials by updating document status.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document 
+ mock_datasource_provider_service.get_datasource_credentials.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + assert mock_document.indexing_status == "error" + assert "Datasource credential not found" in mock_document.error + assert mock_document.stopped_at is not None + mock_db_session.commit.assert_called() + mock_db_session.close.assert_called() + + def test_page_not_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task does nothing when page has not been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + # Return same time as stored in document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document status should remain unchanged + assert mock_document.indexing_status == "completed" + # No session operations should be performed beyond the initial query + mock_db_session.close.assert_not_called() + + def test_successful_sync_when_page_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test successful sync flow when Notion page has been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + # NotionExtractor returns updated time + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Verify document status was updated to parsing + assert mock_document.indexing_status == "parsing" + assert mock_document.processing_started_at is not None + + # Verify segments were cleaned + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + mock_processor.clean.assert_called_once() + + # Verify segments were deleted from database + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + # Verify indexing runner was called + mock_indexing_runner.run.assert_called_once_with([mock_document]) + + # Verify session operations + assert mock_db_session.commit.called + mock_db_session.close.assert_called_once() + + def test_dataset_not_found_during_cleaning( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles dataset not found during cleaning phase.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document should still be set to parsing + assert mock_document.indexing_status == "parsing" + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_cleaning_error_continues_to_indexing( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + 
mock_dataset, + mock_document, + dataset_id, + document_id, + ): + """Test that indexing continues even if cleaning fails.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error") + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Indexing should still be attempted despite cleaning error + mock_indexing_runner.run.assert_called_once_with([mock_document]) + mock_db_session.close.assert_called_once() + + def test_indexing_runner_document_paused_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that DocumentIsPausedError is handled gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after handling error + mock_db_session.close.assert_called_once() + + def test_indexing_runner_general_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that general exceptions during indexing are handled.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_notion_extractor_initialized_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + notion_workspace_id, + notion_page_id, + ): + """Test that NotionExtractor is initialized with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update + + # Act + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + mock_extractor_class.return_value = mock_extractor + + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_extractor_class.assert_called_once_with( + notion_workspace_id=notion_workspace_id, + notion_obj_id=notion_page_id, + notion_page_type="page", + notion_access_token="test_token", + tenant_id=mock_document.tenant_id, + ) + + def 
test_datasource_credentials_requested_correctly( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + credential_id, + ): + """Test that datasource credentials are requested with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_credential_id_missing_uses_none( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credential_id by passing None.""" + # Arrange + mock_document.data_source_info_dict = { + "notion_workspace_id": "ws123", + "notion_page_id": "page123", + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + } + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=None, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_index_processor_clean_called_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that index processor clean is called with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + expected_node_ids = [seg.index_node_id for seg in mock_document_segments] + mock_processor.clean.assert_called_once_with( + mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True + ) diff --git a/dev/start-worker b/dev/start-worker index a01da11d86..7876620188 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -37,6 +37,7 @@ show_help() { echo " pipeline - Standard pipeline tasks" echo " triggered_workflow_dispatcher - Trigger dispatcher tasks" echo " trigger_refresh_executor - Trigger refresh tasks" + echo " retention - Retention tasks" } # Parse command line arguments @@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - 
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docker/.env.example b/docker/.env.example index dd0d083da3..e5cdb64dae 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1369,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# API side timeout (configure to match the Plugin Daemon side above) +PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1479,4 +1482,9 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 ANNOTATION_IMPORT_MAX_CONCURRENT=5 # The API key of amplitude -AMPLITUDE_API_KEY= \ No newline at end of file +AMPLITUDE_API_KEY= + +# Sandbox expired records clean configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 4f6194b9e4..a07ed9e8ad 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -34,6 +34,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index aca4325880..24e1077ebe 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -591,6 +591,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -663,6 +664,9 @@ x-shared-env: &shared-api-worker-env ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} ANNOTATION_IMPORT_MAX_CONCURRENT: 
${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} services: # Init container to fix permissions @@ -699,6 +703,7 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: init_permissions: diff --git a/web/__mocks__/ky.ts b/web/__mocks__/ky.ts new file mode 100644 index 0000000000..6c7691f2cf --- /dev/null +++ b/web/__mocks__/ky.ts @@ -0,0 +1,71 @@ +/** + * Mock for ky HTTP client + * This mock is used to avoid ESM issues in Jest tests + */ + +type KyResponse = { + ok: boolean + status: number + statusText: string + headers: Headers + json: jest.Mock + text: jest.Mock + blob: jest.Mock + arrayBuffer: jest.Mock + clone: jest.Mock +} + +type KyInstance = jest.Mock & { + get: jest.Mock + post: jest.Mock + put: jest.Mock + patch: jest.Mock + delete: jest.Mock + head: jest.Mock + create: jest.Mock + extend: jest.Mock + stop: symbol +} + +const createResponse = (data: unknown = {}, status = 200): KyResponse => { + const response: KyResponse = { + ok: status >= 200 && status < 300, + status, + statusText: status === 200 ? 'OK' : 'Error', + headers: new Headers(), + json: jest.fn().mockResolvedValue(data), + text: jest.fn().mockResolvedValue(JSON.stringify(data)), + blob: jest.fn().mockResolvedValue(new Blob()), + arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)), + clone: jest.fn(), + } + // Ensure clone returns a new response-like object, not the same instance + response.clone.mockImplementation(() => createResponse(data, status)) + return response +} + +const createKyInstance = (): KyInstance => { + const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance + + // HTTP methods + instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + + // Create new instance with custom options + instance.create = jest.fn().mockImplementation(() => createKyInstance()) + instance.extend = jest.fn().mockImplementation(() => createKyInstance()) + + // Stop method for AbortController + instance.stop = Symbol('stop') + + return instance +} + +const ky = createKyInstance() + +export default ky +export { ky } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 1f836de6e6..d5e3c61932 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -16,7 +16,7 @@ 
import { import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' import s from './style.module.css' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 2bfdece433..dda5dff2b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react' import type { Dayjs } from 'dayjs' import type { FC } from 'react' import React, { useCallback } from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' import { useI18N } from '@/context/i18n' import Picker from '@/app/components/base/date-and-time-picker/date-picker' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index f99ea52492..0a80bf670d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select' import type { Item } from '@/app/components/base/select' import dayjs from 'dayjs' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useTranslation } from 'react-i18next' const today = dayjs() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 246a1eb6a3..17c919bf22 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react' import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 628eb13071..767ccb8c59 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const I18N_PREFIX = 
'app.tracing'
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx
index eecd356e08..e170159e35 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx
@@ -1,7 +1,7 @@
 'use client'
 import type { FC } from 'react'
 import React from 'react'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import Input from '@/app/components/base/input'
 type Props = {
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx
index 2c17931b83..319ff3f423 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx
@@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
 import { TracingProvider } from './type'
 import TracingIcon from './tracing-icon'
 import ConfigButton from './config-button'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
 import Indicator from '@/app/components/header/indicator'
 import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx
index ac1704d60d..0779689c76 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx
@@ -6,7 +6,7 @@ import {
 } from '@remixicon/react'
 import { useTranslation } from 'react-i18next'
 import { TracingProvider } from './type'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
 import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx
index ec9117dd38..aeca1cd3ab 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx
@@ -1,7 +1,7 @@
 'use client'
 import type { FC } from 'react'
 import React from 'react'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'
 type Props = {
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
index 3effb79f20..3581587b54 100644
--- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
+++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
@@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
 import useDocumentTitle from '@/hooks/use-document-title'
 import ExtraInfo from '@/app/components/datasets/extra-info'
 import { useEventEmitterContextContext } from '@/context/event-emitter'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 export type IAppDetailLayoutProps = {
   children: React.ReactNode
diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx
index e0ac6b9ad6..13073b0e6a 100644
--- a/web/app/(shareLayout)/webapp-reset-password/layout.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx
@@ -1,7 +1,7 @@
 'use client'
 import Header from '@/app/signin/_header'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 export default function SignInLayout({ children }: any) {
diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
index 5e3f6fff1d..843f10e039 100644
--- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
@@ -2,7 +2,7 @@
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useRouter, useSearchParams } from 'next/navigation'
-import cn from 'classnames'
+import { cn } from '@/utils/classnames'
 import { RiCheckboxCircleFill } from '@remixicon/react'
 import { useCountDown } from 'ahooks'
 import Button from '@/app/components/base/button'
diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx
index 7649982072..c75f925d40 100644
--- a/web/app/(shareLayout)/webapp-signin/layout.tsx
+++ b/web/app/(shareLayout)/webapp-signin/layout.tsx
@@ -1,6 +1,6 @@
 'use client'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import useDocumentTitle from '@/hooks/use-document-title'
 import type { PropsWithChildren } from 'react'
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
index 219722eef3..a14bfcd737 100644
--- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx
+++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
@@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
 import MailAndCodeAuth from './components/mail-and-code-auth'
 import MailAndPasswordAuth from './components/mail-and-password-auth'
 import SSOAuth from './components/sso-auth'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { LicenseStatus } from '@/types/feature'
 import { IS_CE_EDITION } from '@/config'
 import { useGlobalPublicStore } from '@/context/global-public-context'
diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx
index 2ab676d6b6..b70ab210d0 100644
--- a/web/app/account/oauth/authorize/layout.tsx
+++ b/web/app/account/oauth/authorize/layout.tsx
@@ -1,7 +1,7 @@
 'use client'
 import Header from '@/app/signin/_header'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import useDocumentTitle from '@/hooks/use-document-title'
 import { AppContextProvider } from '@/context/app-context'
diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx
index d9d07cbfa1..11fc4866f3 100644
--- a/web/app/activate/activateForm.tsx
+++ b/web/app/activate/activateForm.tsx
@@ -1,13 +1,13 @@
 'use client'
+import { useEffect } from 'react'
 import { useTranslation } from 'react-i18next'
-import useSWR from 'swr'
 import { useRouter, useSearchParams } from 'next/navigation'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import Button from '@/app/components/base/button'
-import { invitationCheck } from '@/service/common'
 import Loading from '@/app/components/base/loading'
 import useDocumentTitle from '@/hooks/use-document-title'
+import { useInvitationCheck } from '@/service/use-common'
 const ActivateForm = () => {
   useDocumentTitle('')
@@ -26,19 +26,21 @@ const ActivateForm = () => {
       token,
     },
   }
-  const { data: checkRes } = useSWR(checkParams, invitationCheck, {
-    revalidateOnFocus: false,
-    onSuccess(data) {
-      if (data.is_valid) {
-        const params = new URLSearchParams(searchParams)
-        const { email, workspace_id } = data.data
-        params.set('email', encodeURIComponent(email))
-        params.set('workspace_id', encodeURIComponent(workspace_id))
-        params.set('invite_token', encodeURIComponent(token as string))
-        router.replace(`/signin?${params.toString()}`)
-      }
-    },
-  })
+  const { data: checkRes } = useInvitationCheck({
+    ...checkParams.params,
+    token: token || undefined,
+  }, true)
+
+  useEffect(() => {
+    if (checkRes?.is_valid) {
+      const params = new URLSearchParams(searchParams)
+      const { email, workspace_id } = checkRes.data
+      params.set('email', encodeURIComponent(email))
+      params.set('workspace_id', encodeURIComponent(workspace_id))
+      params.set('invite_token', encodeURIComponent(token as string))
+      router.replace(`/signin?${params.toString()}`)
+    }
+  }, [checkRes, router, searchParams, token])
   return (
{ diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index f143c2fcef..1b4377c10a 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 3c5d38dd82..04634906af 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -16,7 +16,7 @@ import AppInfo from './app-info' import NavLink from './navLink' import { useStore as useAppStore } from '@/app/components/app/store' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' type Props = { diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index ff110f70bd..dc46af2d02 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import ActionButton from '../../base/action-button' import { RiMoreFill } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Menu from './menu' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/index.spec.tsx new file mode 100644 index 0000000000..3674be6658 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/index.spec.tsx @@ -0,0 +1,379 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetInfo from './index' +import Dropdown from './dropdown' +import Menu from './menu' +import MenuItem from './menu-item' +import type { DataSet } from '@/models/datasets' +import { + ChunkingMode, + DataSourceType, + DatasetPermission, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { RiEditLine } from '@remixicon/react' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = jest.fn() +const mockInvalidDatasetList = jest.fn() +const mockInvalidDatasetDetail = jest.fn() +const mockExportPipeline = jest.fn() +const mockCheckIsUsedInApp = jest.fn() +const mockDeleteDataset = jest.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + 
data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), +})) + +jest.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +jest.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +jest.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +jest.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipeline, + }), +})) + +jest.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args), +})) + +jest.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'indexing-technique', + }), +})) + +jest.mock('@/app/components/datasets/rename-modal', () => ({ + __esModule: true, + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +const openMenu = async (user: ReturnType) => { + const trigger = screen.getByRole('button') + await user.click(trigger) +} + +describe('DatasetInfo', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + mockIsDatasetOperator = false + }) + + // Rendering of dataset summary details based on expand and dataset state. + describe('Rendering', () => { + it('should show dataset details when expanded', () => { + // Arrange + mockDataset = createDataset({ is_published: true }) + render() + + // Assert + expect(screen.getByText('Dataset Name')).toBeInTheDocument() + expect(screen.getByText('Dataset description')).toBeInTheDocument() + expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument() + expect(screen.getByText('indexing-technique')).toBeInTheDocument() + }) + + it('should show external tag when provider is external', () => { + // Arrange + mockDataset = createDataset({ provider: 'external', is_published: false }) + render() + + // Assert + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument() + }) + + it('should hide detailed fields when collapsed', () => { + // Arrange + render() + + // Assert + expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument() + expect(screen.queryByText('Dataset description')).not.toBeInTheDocument() + }) + }) +}) + +describe('MenuItem', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Event handling for menu item interactions. + describe('Interactions', () => { + it('should call handler when clicked', async () => { + const user = userEvent.setup() + const handleClick = jest.fn() + // Arrange + render() + + // Act + await user.click(screen.getByText('Edit')) + + // Assert + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Menu', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + }) + + // Rendering of menu options based on runtime mode and delete visibility. 
+ describe('Rendering', () => { + it('should show edit, export, and delete options when rag pipeline and deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'rag_pipeline' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + + it('should hide export and delete options when not rag pipeline and not deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'general' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) +}) + +describe('Dropdown', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + if (!('createObjectURL' in URL)) { + Object.defineProperty(URL, 'createObjectURL', { + value: jest.fn(), + writable: true, + }) + } + if (!('revokeObjectURL' in URL)) { + Object.defineProperty(URL, 'revokeObjectURL', { + value: jest.fn(), + writable: true, + }) + } + }) + + // Rendering behavior based on workspace role. + describe('Rendering', () => { + it('should hide delete option when user is dataset operator', async () => { + const user = userEvent.setup() + // Arrange + mockIsDatasetOperator = true + render() + + // Act + await openMenu(user) + + // Assert + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) + + // User interactions that trigger modals and exports. 
+ describe('Interactions', () => { + it('should open rename modal when edit is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + }) + + it('should export pipeline when export is clicked', async () => { + const user = userEvent.setup() + const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click') + const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL') + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('datasetPipeline.operations.exportPipeline')) + + // Assert + await waitFor(() => { + expect(mockExportPipeline).toHaveBeenCalledWith({ + pipelineId: 'pipeline-1', + include: false, + }) + }) + expect(createObjectURLSpy).toHaveBeenCalledTimes(1) + expect(anchorClickSpy).toHaveBeenCalledTimes(1) + }) + + it('should show delete confirmation when delete is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + + // Assert + await waitFor(() => { + expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument() + }) + }) + + it('should delete dataset and redirect when confirm is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' })) + + // Assert + await waitFor(() => { + expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1') + }) + expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1) + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.tsx b/web/app/components/app-sidebar/dataset-info/index.tsx index 44b0baa72b..bace656d54 100644 --- a/web/app/components/app-sidebar/dataset-info/index.tsx +++ b/web/app/components/app-sidebar/dataset-info/index.tsx @@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import type { DataSet } from '@/models/datasets' import { DOC_FORM_TEXT } from '@/models/datasets' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Dropdown from './dropdown' type DatasetInfoProps = { diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index ac07333712..cf380d00d2 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon' import Divider from '../base/divider' import NavLink from './navLink' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import Effect from '../base/effect' import Dropdown from './dataset-info/dropdown' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 86de2e2034..fe52c4cfa2 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -9,7 +9,7 @@ import AppSidebarDropdown 
from './app-sidebar-dropdown' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { useHover, useKeyPress } from 'ahooks' import ToggleButton from './toggle-button' diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index ad90b91250..f6d8e57682 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useSelectedLayoutSegment } from 'next/navigation' import Link from 'next/link' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { RemixiconComponentType } from '@remixicon/react' export type NavIcon = React.ComponentType< @@ -42,7 +42,7 @@ const NavLink = ({ const NavIcon = isActive ? iconMap.selected : iconMap.normal const renderIcon = () => ( -
+
) @@ -53,21 +53,17 @@ const NavLink = ({ key={name} type='button' disabled - className={classNames( - 'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', - 'pl-3 pr-1', - )} + className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', + 'pl-3 pr-1')} title={mode === 'collapse' ? name : ''} aria-disabled > {renderIcon()} {name} @@ -79,22 +75,18 @@ const NavLink = ({ {renderIcon()} {name} diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index 8de6f887f6..4f69adfc34 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -1,7 +1,7 @@ import React from 'react' import Button from '../base/button' import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '../base/tooltip' import { useTranslation } from 'react-i18next' import { getKeyboardKeyNameBySystem } from '../workflow/utils' diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx index 356f813afc..f226adf22b 100644 --- a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx @@ -2,12 +2,6 @@ import React from 'react' import { fireEvent, render, screen } from '@testing-library/react' import EditItem, { EditItemType } from './index' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('AddAnnotationModal/EditItem', () => { test('should render query inputs with user avatar and placeholder strings', () => { render( diff --git a/web/app/components/app/annotation/batch-action.spec.tsx b/web/app/components/app/annotation/batch-action.spec.tsx new file mode 100644 index 0000000000..36440fc044 --- /dev/null +++ b/web/app/components/app/annotation/batch-action.spec.tsx @@ -0,0 +1,42 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchAction from './batch-action' + +describe('BatchAction', () => { + const baseProps = { + selectedIds: ['1', '2', '3'], + onBatchDelete: jest.fn(), + onCancel: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should show the selected count and trigger cancel action', () => { + render() + + expect(screen.getByText('3')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(baseProps.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should confirm before running batch delete', async () => { + const onBatchDelete = jest.fn().mockResolvedValue(undefined) + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' })) + await screen.findByText('appAnnotation.list.delete.title') + + await act(async () => { + fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1]) + }) + + await waitFor(() => { + expect(onBatchDelete).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git 
a/web/app/components/app/annotation/batch-action.tsx b/web/app/components/app/annotation/batch-action.tsx index 6e80d0c4c8..6ff392d17e 100644 --- a/web/app/components/app/annotation/batch-action.tsx +++ b/web/app/components/app/annotation/batch-action.tsx @@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import Divider from '@/app/components/base/divider' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' const i18nPrefix = 'appAnnotation.batchAction' @@ -38,7 +38,7 @@ const BatchAction: FC = ({ setIsNotDeleting() } return ( -
+
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx new file mode 100644 index 0000000000..7d360cfc1b --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx @@ -0,0 +1,72 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import CSVDownload from './csv-downloader' +import I18nContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { Locale } from '@/i18n-config' + +const downloaderProps: any[] = [] + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: jest.fn(() => ({ + CSVDownloader: ({ children, ...props }: any) => { + downloaderProps.push(props) + return
{children}
+ }, + Type: { Link: 'link' }, + })), +})) + +const renderWithLocale = (locale: Locale) => { + return render( + + + , + ) +} + +describe('CSVDownload', () => { + const englishTemplate = [ + ['question', 'answer'], + ['question1', 'answer1'], + ['question2', 'answer2'], + ] + const chineseTemplate = [ + ['问题', '答案'], + ['问题 1', '答案 1'], + ['问题 2', '答案 2'], + ] + + beforeEach(() => { + downloaderProps.length = 0 + }) + + it('should render the structure preview and pass English template data by default', () => { + renderWithLocale('en-US' as Locale) + + expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument() + + expect(downloaderProps[0]).toMatchObject({ + filename: 'template-en-US', + type: 'link', + bom: true, + data: englishTemplate, + }) + }) + + it('should switch to the Chinese template when locale matches the secondary language', () => { + const locale = LanguagesSupported[1] as Locale + renderWithLocale(locale) + + expect(downloaderProps[0]).toMatchObject({ + filename: `template-${locale}`, + data: chineseTemplate, + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index ccad46b860..c9766135df 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { RiDeleteBinLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import Button from '@/app/components/base/button' diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..5527340895 --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx @@ -0,0 +1,164 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchModal, { ProcessStatus } from './index' +import { useProviderContext } from '@/context/provider-context' +import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' +import type { IBatchModalProps } from './index' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(), + }, +})) + +jest.mock('@/service/annotation', () => ({ + annotationBatchImport: jest.fn(), + checkAnnotationBatchImportProgress: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./csv-downloader', () => ({ + __esModule: true, + default: () =>
, +})) + +let lastUploadedFile: File | undefined + +jest.mock('./csv-uploader', () => ({ + __esModule: true, + default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => ( +
+ + {file && {file.name}} +
+ ), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +const mockNotify = Toast.notify as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock +const annotationBatchImportMock = annotationBatchImport as jest.Mock +const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock + +const renderComponent = (props: Partial = {}) => { + const mergedProps: IBatchModalProps = { + appId: 'app-id', + isShow: true, + onCancel: jest.fn(), + onAdded: jest.fn(), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('BatchModal', () => { + beforeEach(() => { + jest.clearAllMocks() + lastUploadedFile = undefined + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should disable run action and show billing hint when annotation quota is full', () => { + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 10 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }) + + renderComponent() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled() + }) + + it('should reset uploader state when modal closes and allow manual cancellation', () => { + const { rerender, props } = renderComponent() + + fireEvent.click(screen.getByTestId('mock-uploader')) + expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv') + + rerender() + rerender() + + expect(screen.queryByTestId('selected-file')).toBeNull() + + fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' })) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should submit the csv file, poll status, and notify when import completes', async () => { + jest.useFakeTimers() + const { props } = renderComponent() + const fileTrigger = screen.getByTestId('mock-uploader') + fireEvent.click(fileTrigger) + + const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' }) + expect(runButton).not.toBeDisabled() + + annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + checkAnnotationBatchImportProgressMock + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED }) + + await act(async () => { + fireEvent.click(runButton) + }) + + await waitFor(() => { + expect(annotationBatchImportMock).toHaveBeenCalledTimes(1) + }) + + const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData + expect(formData.get('file')).toBe(lastUploadedFile) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1) + }) + + await act(async () => { + jest.runOnlyPendingTimers() + }) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'appAnnotation.batchModal.completed', + }) + expect(props.onAdded).toHaveBeenCalledTimes(1) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + jest.useRealTimers() + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx index 1f32e55928..95a5586292 100644 --- 
a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx @@ -245,7 +245,7 @@ describe('EditItem', () => { expect(mockSave).toHaveBeenCalledWith('Test save content') }) - it('should show delete option when content changes', async () => { + it('should show delete option and restore original content when delete is clicked', async () => { // Arrange const mockSave = jest.fn().mockResolvedValue(undefined) const props = { @@ -267,7 +267,13 @@ describe('EditItem', () => { await user.click(screen.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(mockSave).toHaveBeenCalledWith('Modified content') + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(await screen.findByText('common.operation.delete')).toBeInTheDocument() + + await user.click(screen.getByText('common.operation.delete')) + + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() }) it('should handle keyboard interactions in edit mode', async () => { @@ -393,5 +399,68 @@ describe('EditItem', () => { expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.getByText('Test content')).toBeInTheDocument() }) + + it('should handle save failure gracefully in edit mode', async () => { + // Arrange + const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed')) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and save (should fail) + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Save should fail but not throw + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert - Should remain in edit mode when save fails + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(mockSave).toHaveBeenCalledWith('New content') + }) + + it('should handle delete action failure gracefully', async () => { + // Arrange + const mockSave = jest.fn() + .mockResolvedValueOnce(undefined) // First save succeeds + .mockRejectedValueOnce(new Error('Delete failed')) // Delete fails + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Edit content to show delete button + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to create new content + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + await screen.findByText('common.operation.delete') + + // Click delete (should fail but not throw) + await user.click(screen.getByText('common.operation.delete')) + + // Assert - Delete action should handle error gracefully + expect(mockSave).toHaveBeenCalledTimes(2) + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + + // When delete fails, the delete button should still be visible (state not changed) + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + expect(screen.getByText('Modified content')).toBeInTheDocument() + }) }) }) diff 
--git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index e808d0b48a..6ba830967d 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react' import { Robot, User } from '@/app/components/base/icons/src/public/avatar' import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export enum EditItemType { Query = 'query', @@ -52,8 +52,14 @@ const EditItem: FC = ({ }, [content]) const handleSave = async () => { - await onSave(newContent) - setIsEdit(false) + try { + await onSave(newContent) + setIsEdit(false) + } + catch { + // Keep edit mode open when save fails + // Error notification is handled by the parent component + } } const handleCancel = () => { @@ -96,9 +102,16 @@ const EditItem: FC = ({
·
{ - setNewContent(content) - onSave(content) + onClick={async () => { + try { + await onSave(content) + // Only update UI state after successful delete + setNewContent(content) + } + catch { + // Delete action failed - error is already handled by parent + // UI state remains unchanged, user can retry + } }} >
diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx index a2e2527605..bdc991116c 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -1,4 +1,4 @@ -import { render, screen } from '@testing-library/react' +import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast' import EditAnnotationModal from './index' @@ -405,4 +405,276 @@ describe('EditAnnotationModal', () => { expect(editLinks).toHaveLength(1) // Only answer should have edit button }) }) + + // Error Handling (CRITICAL for coverage) + describe('Error Handling', () => { + it('should show error toast and skip callbacks when addAnnotation fails', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock API failure + mockAddAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when addAnnotation error has no message', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + mockAddAnnotation.mockRejectedValueOnce({}) + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show error toast and skip callbacks when editAnnotation fails', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Mock API failure + 
mockEditAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when editAnnotation error is not an Error instance', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + mockEditAnnotation.mockRejectedValueOnce('oops') + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + }) + + // Billing & Plan Features + describe('Billing & Plan Features', () => { + it('should show createdAt time when provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + createdAt: 1701381000, // 2023-12-01 10:30:00 + } + + // Act + render() + + // Assert - Check that the formatted time appears somewhere in the component + const container = screen.getByRole('dialog') + expect(container).toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should not show createdAt when not provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + // createdAt is undefined + } + + // Act + render() + + // Assert - Should not contain any timestamp + const container = screen.getByRole('dialog') + expect(container).not.toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should display remove section when annotationId exists', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Should have remove functionality + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // Toast Notifications (Success) + describe('Toast Notifications', () => { + it('should show success notification when save operation completes', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + const editLinks = 
screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionSuccess', + type: 'success', + }) + }) + }) + }) + + // React.memo Performance Testing + describe('React.memo Performance', () => { + it('should not re-render when props are the same', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with same props + rerender() + + // Assert - Component should still be visible (no errors thrown) + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should re-render when props change', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with different props + const newProps = { ...props, query: 'New query content' } + rerender() + + // Assert - Should show new content + expect(screen.getByText('New query content')).toBeInTheDocument() + }) + }) }) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 2961ce393c..6172a215e4 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -53,27 +53,39 @@ const EditAnnotationModal: FC = ({ postQuery = editedContent else postAnswer = editedContent - if (!isAdd) { - await editAnnotation(appId, annotationId, { - message_id: messageId, - question: postQuery, - answer: postAnswer, - }) - onEdited(postQuery, postAnswer) - } - else { - const res: any = await addAnnotation(appId, { - question: postQuery, - answer: postAnswer, - message_id: messageId, - }) - onAdded(res.id, res.account?.name, postQuery, postAnswer) - } + try { + if (!isAdd) { + await editAnnotation(appId, annotationId, { + message_id: messageId, + question: postQuery, + answer: postAnswer, + }) + onEdited(postQuery, postAnswer) + } + else { + const res = await addAnnotation(appId, { + question: postQuery, + answer: postAnswer, + message_id: messageId, + }) + onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer) + } - Toast.notify({ - message: t('common.api.actionSuccess') as string, - type: 'success', - }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + } + catch (error) { + const fallbackMessage = t('common.api.actionFailed') as string + const message = error instanceof Error && error.message ? 
error.message : fallbackMessage + Toast.notify({ + message, + type: 'error', + }) + // Re-throw to preserve edit mode behavior for UI components + throw error + } } const [showModal, setShowModal] = useState(false) diff --git a/web/app/components/app/annotation/empty-element.spec.tsx b/web/app/components/app/annotation/empty-element.spec.tsx new file mode 100644 index 0000000000..56ebb96121 --- /dev/null +++ b/web/app/components/app/annotation/empty-element.spec.tsx @@ -0,0 +1,13 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import EmptyElement from './empty-element' + +describe('EmptyElement', () => { + it('should render the empty state copy and supporting icon', () => { + const { container } = render() + + expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument() + expect(container.querySelector('svg')).not.toBeNull() + }) +}) diff --git a/web/app/components/app/annotation/filter.spec.tsx b/web/app/components/app/annotation/filter.spec.tsx new file mode 100644 index 0000000000..6260ff7668 --- /dev/null +++ b/web/app/components/app/annotation/filter.spec.tsx @@ -0,0 +1,70 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { type QueryParam } from './filter' +import useSWR from 'swr' + +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +jest.mock('@/service/log', () => ({ + fetchAnnotationsCount: jest.fn(), +})) + +const mockUseSWR = useSWR as unknown as jest.Mock + +describe('Filter', () => { + const appId = 'app-1' + const childContent = 'child-content' + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render nothing until annotation count is fetched', () => { + mockUseSWR.mockReturnValue({ data: undefined }) + + const { container } = render( + +
{childContent}
+
, + ) + + expect(container.firstChild).toBeNull() + expect(mockUseSWR).toHaveBeenCalledWith( + { url: `/apps/${appId}/annotations/count` }, + expect.any(Function), + ) + }) + + it('should propagate keyword changes and clearing behavior', () => { + mockUseSWR.mockReturnValue({ data: { total: 20 } }) + const queryParams: QueryParam = { keyword: 'prefill' } + const setQueryParams = jest.fn() + + const { container } = render( + +
{childContent}
+
, + ) + + const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement + fireEvent.change(input, { target: { value: 'updated' } }) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' }) + + const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement + fireEvent.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' }) + + expect(container).toHaveTextContent(childContent) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.spec.tsx b/web/app/components/app/annotation/header-opts/index.spec.tsx new file mode 100644 index 0000000000..3d8a1fd4ef --- /dev/null +++ b/web/app/components/app/annotation/header-opts/index.spec.tsx @@ -0,0 +1,439 @@ +import * as React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import type { ComponentProps } from 'react' +import HeaderOptions from './index' +import I18NContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { AnnotationItemBasic } from '../type' +import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' + +jest.mock('@headlessui/react', () => { + type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void } + type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void } + const PopoverContext = React.createContext(null) + const MenuContext = React.createContext(null) + + const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {typeof children === 'function' ? children({ open }) : children} + + ) + } + + const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + }) + + const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + if (!context?.open) return null + const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children + return ( +
+ {content} +
+ ) + }) + + const Menu = ({ children }: { children: React.ReactNode }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {children} + + ) + } + + const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => { + const context = React.useContext(MenuContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + } + + const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => { + const context = React.useContext(MenuContext) + if (!context?.open) return null + return ( +
+ {children} +
+ ) + } + + return { + Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => { + if (open === false) return null + return ( +
+ {children} +
+ ) + }, + DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => ( +
+ {children} +
+ ), + DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
+ {children} +
+ ), + Popover, + PopoverButton, + PopoverPanel, + Menu, + MenuButton, + MenuItems, + Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children} : null), + TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}, + } +}) + +let lastCSVDownloaderProps: Record | undefined +const mockCSVDownloader = jest.fn(({ children, ...props }) => { + lastCSVDownloaderProps = props + return ( +
+ {children} +
+ ) +}) + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: () => ({ + CSVDownloader: (props: any) => mockCSVDownloader(props), + Type: { Link: 'link' }, + }), +})) + +jest.mock('@/service/annotation', () => ({ + fetchExportAnnotationList: jest.fn(), + clearAllAnnotations: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type HeaderOptionsProps = ComponentProps + +const renderComponent = ( + props: Partial = {}, + locale: string = LanguagesSupported[0] as string, +) => { + const defaultProps: HeaderOptionsProps = { + appId: 'test-app-id', + onAdd: jest.fn(), + onAdded: jest.fn(), + controlUpdateList: 0, + ...props, + } + + return render( + + + , + ) +} + +const openOperationsPopover = async (user: ReturnType) => { + const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement + expect(trigger).toBeTruthy() + await user.click(trigger) +} + +const expandExportMenu = async (user: ReturnType) => { + await openOperationsPopover(user) + const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport') + const exportButton = exportLabel.closest('button') as HTMLButtonElement + expect(exportButton).toBeTruthy() + await user.click(exportButton) +} + +const getExportButtons = async () => { + const csvLabel = await screen.findByText('CSV') + const jsonLabel = await screen.findByText('JSONL') + const csvButton = csvLabel.closest('button') as HTMLButtonElement + const jsonButton = jsonLabel.closest('button') as HTMLButtonElement + expect(csvButton).toBeTruthy() + expect(jsonButton).toBeTruthy() + return { + csvButton, + jsonButton, + } +} + +const clickOperationAction = async ( + user: ReturnType, + translationKey: string, +) => { + const label = await screen.findByText(translationKey) + const button = label.closest('button') as HTMLButtonElement + expect(button).toBeTruthy() + await user.click(button) +} + +const mockAnnotations: AnnotationItemBasic[] = [ + { + question: 'Question 1', + answer: 'Answer 1', + }, +] + +const mockedFetchAnnotations = jest.mocked(fetchExportAnnotationList) +const mockedClearAllAnnotations = jest.mocked(clearAllAnnotations) + +describe('HeaderOptions', () => { + beforeEach(() => { + jest.clearAllMocks() + jest.useRealTimers() + mockCSVDownloader.mockClear() + lastCSVDownloaderProps = undefined + mockedFetchAnnotations.mockResolvedValue({ data: [] }) + }) + + it('should fetch annotations on mount and render enabled export actions when data exist', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + renderComponent() + + await waitFor(() => { + expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id') + }) + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).not.toBeDisabled() + expect(jsonButton).not.toBeDisabled() + + await waitFor(() => { + expect(lastCSVDownloaderProps).toMatchObject({ + bom: true, + filename: 'annotations-en-US', + type: 'link', + data: [ + ['Question', 'Answer'], + ['Question 1', 'Answer 1'], + ], + }) + }) + }) + + it('should disable export actions when there are no annotations', async () => { + const user = userEvent.setup() + renderComponent() + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).toBeDisabled() + expect(jsonButton).toBeDisabled() + + expect(lastCSVDownloaderProps).toMatchObject({ + data: [['Question', 'Answer']], + }) + }) + + it('should open the add annotation modal and forward the onAdd callback', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const onAdd = jest.fn().mockResolvedValue(undefined) + renderComponent({ onAdd }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled()) + + await user.click( + 
screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }), + ) + + await screen.findByText('appAnnotation.addModal.title') + const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder') + const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder') + + await user.type(questionInput, 'Integration question') + await user.type(answerInput, 'Integration answer') + await user.click(screen.getByRole('button', { name: 'common.operation.add' })) + + await waitFor(() => { + expect(onAdd).toHaveBeenCalledWith({ + question: 'Integration question', + answer: 'Integration answer', + }) + }) + }) + + it('should allow bulk import through the batch modal', async () => { + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.bulkImport') + + expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument() + await user.click( + screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }), + ) + expect(onAdded).not.toHaveBeenCalled() + }) + + it('should trigger JSONL download with locale-specific filename', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const originalCreateElement = document.createElement.bind(document) + const anchor = originalCreateElement('a') as HTMLAnchorElement + const clickSpy = jest.spyOn(anchor, 'click').mockImplementation(jest.fn()) + const createElementSpy = jest + .spyOn(document, 'createElement') + .mockImplementation((tagName: Parameters[0]) => { + if (tagName === 'a') + return anchor + return originalCreateElement(tagName) + }) + const objectURLSpy = jest + .spyOn(URL, 'createObjectURL') + .mockReturnValue('blob://mock-url') + const revokeSpy = jest.spyOn(URL, 'revokeObjectURL').mockImplementation(jest.fn()) + + renderComponent({}, LanguagesSupported[1] as string) + + await expandExportMenu(user) + + await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled()) + + const { jsonButton } = await getExportButtons() + await user.click(jsonButton) + + expect(createElementSpy).toHaveBeenCalled() + expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`) + expect(clickSpy).toHaveBeenCalled() + expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url') + + const blobArg = objectURLSpy.mock.calls[0][0] as Blob + await expect(blobArg.text()).resolves.toContain('"Question 1"') + + clickSpy.mockRestore() + createElementSpy.mockRestore() + objectURLSpy.mockRestore() + revokeSpy.mockRestore() + }) + + it('should clear all annotations when confirmation succeeds', async () => { + mockedClearAllAnnotations.mockResolvedValue(undefined) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id') + expect(onAdded).toHaveBeenCalled() + }) + }) + + it('should handle clear all failures gracefully', async () => { + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + 
mockedClearAllAnnotations.mockRejectedValue(new Error('network')) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalled() + expect(onAdded).not.toHaveBeenCalled() + expect(consoleSpy).toHaveBeenCalled() + }) + + consoleSpy.mockRestore() + }) + + it('should refetch annotations when controlUpdateList changes', async () => { + const view = renderComponent({ controlUpdateList: 0 }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1)) + + view.rerender( + + + , + ) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2)) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 024f75867c..5f8ef658e7 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -17,7 +17,7 @@ import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import type { AnnotationItemBasic } from '../type' import BatchAddModal from '../batch-add-annotation-modal' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import CustomPopover from '@/app/components/base/popover' import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' diff --git a/web/app/components/app/annotation/index.spec.tsx b/web/app/components/app/annotation/index.spec.tsx new file mode 100644 index 0000000000..4971f5173c --- /dev/null +++ b/web/app/components/app/annotation/index.spec.tsx @@ -0,0 +1,233 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import Annotation from './index' +import type { AnnotationItem } from './type' +import { JobStatus } from './type' +import { type App, AppModeEnum } from '@/types/app' +import { + addAnnotation, + delAnnotation, + delAnnotations, + fetchAnnotationConfig, + fetchAnnotationList, + queryAnnotationJobStatus, +} from '@/service/annotation' +import { useProviderContext } from '@/context/provider-context' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { notify: jest.fn() }, +})) + +jest.mock('ahooks', () => ({ + useDebounce: (value: any) => value, +})) + +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + delAnnotation: jest.fn(), + delAnnotations: jest.fn(), + fetchAnnotationConfig: jest.fn(), + editAnnotation: jest.fn(), + fetchAnnotationList: jest.fn(), + queryAnnotationJobStatus: jest.fn(), + updateAnnotationScore: jest.fn(), + updateAnnotationStatus: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./filter', () => ({ children }: { children: React.ReactNode }) => ( +
 <div>{children}</div>
+))
+
+jest.mock('./empty-element', () => () => <div data-testid="empty-element" />)
+
+jest.mock('./header-opts', () => (props: any) => (
+  <div>
+    <button data-testid="trigger-add" onClick={() => props.onAdd({ question: 'new question', answer: 'new answer' })} />
+  </div>
+))
+
+let latestListProps: any
+
+jest.mock('./list', () => (props: any) => {
+  latestListProps = props
+  if (!props.list.length)
+    return <div />
+  return (
+    <div data-testid="list">
+      <button data-testid="list-view" onClick={() => props.onView(props.list[0])} />
+    </div>
+  )
+})
+
+jest.mock('./view-annotation-modal', () => (props: any) => {
+  if (!props.isShow)
+    return null
+  return (
+    <div data-testid="view-modal">
+      <div>{props.item.question}</div>
+      <button data-testid="view-modal-remove" onClick={() => props.onRemove()} />
+    </div>
+  )
+})
+
+jest.mock('@/app/components/base/pagination', () => () => <div />)
+jest.mock('@/app/components/base/loading', () => () => <div />)
+jest.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => (props: any) => props.isShow ? <div /> : null)
+jest.mock('@/app/components/billing/annotation-full/modal', () => (props: any) => props.show ? <div />
: null) + +const mockNotify = Toast.notify as jest.Mock +const addAnnotationMock = addAnnotation as jest.Mock +const delAnnotationMock = delAnnotation as jest.Mock +const delAnnotationsMock = delAnnotations as jest.Mock +const fetchAnnotationConfigMock = fetchAnnotationConfig as jest.Mock +const fetchAnnotationListMock = fetchAnnotationList as jest.Mock +const queryAnnotationJobStatusMock = queryAnnotationJobStatus as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock + +const appDetail = { + id: 'app-id', + mode: AppModeEnum.CHAT, +} as App + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-1', + question: overrides.question ?? 'Question 1', + answer: overrides.answer ?? 'Answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const renderComponent = () => render() + +describe('Annotation', () => { + beforeEach(() => { + jest.clearAllMocks() + latestListProps = undefined + fetchAnnotationConfigMock.mockResolvedValue({ + id: 'config-id', + enabled: false, + embedding_model: { + embedding_model_name: 'model', + embedding_provider_name: 'provider', + }, + score_threshold: 0.5, + }) + fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 }) + queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed }) + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should render empty element when no annotations are returned', async () => { + renderComponent() + + expect(await screen.findByTestId('empty-element')).toBeInTheDocument() + expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({ + page: 1, + keyword: '', + })) + }) + + it('should handle annotation creation and refresh list data', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + addAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + + await screen.findByTestId('list') + fireEvent.click(screen.getByTestId('trigger-add')) + + await waitFor(() => { + expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' }) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + message: 'common.api.actionSuccess', + type: 'success', + })) + }) + expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2) + }) + + it('should support viewing items and running batch deletion success flow', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + delAnnotationsMock.mockResolvedValue(undefined) + delAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + await waitFor(() => { + expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id]) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(latestListProps.selectedIds).toEqual([]) + }) + + fireEvent.click(screen.getByTestId('list-view')) + expect(screen.getByTestId('view-modal')).toBeInTheDocument() + + 
await act(async () => { + fireEvent.click(screen.getByTestId('view-modal-remove')) + }) + await waitFor(() => { + expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id) + }) + }) + + it('should show an error notification when batch deletion fails', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + const error = new Error('failed') + delAnnotationsMock.mockRejectedValue(error) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: error.message, + }) + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + }) +}) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 32d0c799fc..2d639c91e4 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -25,7 +25,7 @@ import { sleep } from '@/utils' import { useProviderContext } from '@/context/provider-context' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { type App, AppModeEnum } from '@/types/app' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { delAnnotations } from '@/service/annotation' type Props = { diff --git a/web/app/components/app/annotation/list.spec.tsx b/web/app/components/app/annotation/list.spec.tsx new file mode 100644 index 0000000000..9f8d4c8855 --- /dev/null +++ b/web/app/components/app/annotation/list.spec.tsx @@ -0,0 +1,116 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import List from './list' +import type { AnnotationItem } from './type' + +const mockFormatTime = jest.fn(() => 'formatted-time') + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question 1', + answer: overrides.answer ?? 'answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 
2, +}) + +const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]') + +describe('List', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render annotation rows and call onView when clicking a row', () => { + const item = createAnnotation() + const onView = jest.fn() + + render( + , + ) + + fireEvent.click(screen.getByText(item.question)) + + expect(onView).toHaveBeenCalledWith(item) + expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat') + }) + + it('should toggle single and bulk selection states', () => { + const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })] + const onSelectedIdsChange = jest.fn() + const { container, rerender } = render( + , + ) + + const checkboxes = getCheckboxes(container) + fireEvent.click(checkboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a']) + + rerender( + , + ) + const updatedCheckboxes = getCheckboxes(container) + fireEvent.click(updatedCheckboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith([]) + + fireEvent.click(updatedCheckboxes[0]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b']) + }) + + it('should confirm before removing an annotation and expose batch actions', async () => { + const item = createAnnotation({ id: 'to-delete', question: 'Delete me' }) + const onRemove = jest.fn() + render( + , + ) + + const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement + const actionButtons = within(row).getAllByRole('button') + fireEvent.click(actionButtons[1]) + + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + expect(onRemove).toHaveBeenCalledWith(item.id) + + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 4135b4362e..62a0c50e60 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -7,7 +7,7 @@ import type { AnnotationItem } from './type' import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal' import ActionButton from '@/app/components/base/action-button' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import BatchAction from './batch-action' diff --git a/web/app/components/app/annotation/type.ts b/web/app/components/app/annotation/type.ts index 5df6f51ace..e2f2264f07 100644 --- a/web/app/components/app/annotation/type.ts +++ b/web/app/components/app/annotation/type.ts @@ -12,6 +12,12 @@ export type AnnotationItem = { hit_count: number } +export type AnnotationCreateResponse = AnnotationItem & { + account?: { + name?: string + } +} + export type HitHistoryItem = { id: string question: string diff --git a/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..dec0ad0c01 --- /dev/null +++ b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx @@ -0,0 +1,158 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import 
ViewAnnotationModal from './index' +import type { AnnotationItem, HitHistoryItem } from '../type' +import { fetchHitHistoryList } from '@/service/annotation' + +const mockFormatTime = jest.fn(() => 'formatted-time') + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +jest.mock('@/service/annotation', () => ({ + fetchHitHistoryList: jest.fn(), +})) + +jest.mock('../edit-annotation-modal/edit-item', () => { + const EditItemType = { + Query: 'query', + Answer: 'answer', + } + return { + __esModule: true, + default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => ( +
 <div>
+        <div>{content}</div>
+        <button data-testid={`edit-${type}`} onClick={() => onSave(`${type}-updated`)} />
+      </div>
+ ), + EditItemType, + } +}) + +const fetchHitHistoryListMock = fetchHitHistoryList as jest.Mock + +const createAnnotationItem = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question', + answer: overrides.answer ?? 'answer', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const createHitHistoryItem = (overrides: Partial = {}): HitHistoryItem => ({ + id: overrides.id ?? 'hit-id', + question: overrides.question ?? 'query', + match: overrides.match ?? 'match', + response: overrides.response ?? 'response', + source: overrides.source ?? 'source', + score: overrides.score ?? 0.42, + created_at: overrides.created_at ?? 1700000000, +}) + +const renderComponent = (props?: Partial>) => { + const item = createAnnotationItem() + const mergedProps: React.ComponentProps = { + appId: 'app-id', + isShow: true, + onHide: jest.fn(), + item, + onSave: jest.fn().mockResolvedValue(undefined), + onRemove: jest.fn().mockResolvedValue(undefined), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('ViewAnnotationModal', () => { + beforeEach(() => { + jest.clearAllMocks() + fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 }) + }) + + it('should render annotation tab and allow saving updated query', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-query')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer) + }) + }) + + it('should render annotation tab and allow saving updated answer', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-answer')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated') + }, + ) + }) + + it('should switch to hit history tab and show no data message', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory')) + + // Assert + expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat') + }) + + it('should render hit history entries with pagination badge when data exists', async () => { + const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })] + fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 }) + + renderComponent() + + fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory')) + + expect(await screen.findByText('user input')).toBeInTheDocument() + expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat') + }) + + it('should confirm before removing the annotation and hide on success', async () => { + const { props } = renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + expect(await 
screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(props.onRemove).toHaveBeenCalledTimes(1) + expect(props.onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx index 8426ab0005..d21b177098 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain' import { fetchHitHistoryList } from '@/service/annotation' import { APP_PAGE_LIMIT } from '@/config' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { appId: string diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index ee3fa9650b..99cf6d7074 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react' import type { ReactNode } from 'react' import { Dialog, Transition } from '@headlessui/react' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type DialogProps = { className?: string diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx new file mode 100644 index 0000000000..ea0e17de2e --- /dev/null +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -0,0 +1,389 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AccessControl from './index' +import AccessControlDialog from './access-control-dialog' +import AccessControlItem from './access-control-item' +import AddMemberOrGroupDialog from './add-member-or-group-pop' +import SpecificGroupsOrMembers from './specific-groups-or-members' +import useAccessControlStore from '@/context/access-control-store' +import { useGlobalPublicStore } from '@/context/global-public-context' +import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' +import { AccessMode, SubjectType } from '@/models/access-control' +import Toast from '../../base/toast' +import { defaultSystemFeatures } from '@/types/feature' +import type { App } from '@/types/app' + +const mockUseAppWhiteListSubjects = jest.fn() +const mockUseSearchForWhiteListCandidates = jest.fn() +const mockMutateAsync = jest.fn() +const mockUseUpdateAccessMode = jest.fn(() => ({ + isPending: false, + mutateAsync: mockMutateAsync, +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (value: { userProfile: { email: string; id?: string; name?: string; avatar?: string; avatar_url?: string; is_password_set?: boolean } }) => T) => selector({ + userProfile: { + id: 'current-user', + name: 'Current User', + email: 'member@example.com', + avatar: '', + avatar_url: '', + is_password_set: true, + }, + }), +})) + +jest.mock('@/service/common', () => ({ + fetchCurrentWorkspace: jest.fn(), + 
fetchLangGeniusVersion: jest.fn(), + fetchUserProfile: jest.fn(), + getSystemFeatures: jest.fn(), +})) + +jest.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), + useUpdateAccessMode: () => mockUseUpdateAccessMode(), +})) + +jest.mock('@headlessui/react', () => { + const DialogComponent: any = ({ children, className, ...rest }: any) => ( +
 <div role="dialog" className={className} {...rest}>{children}</div>
+  )
+  DialogComponent.Panel = ({ children, className, ...rest }: any) => (
+    <div className={className} {...rest}>{children}</div>
+  )
+  const DialogTitle = ({ children, className, ...rest }: any) => (
+    <div className={className} {...rest}>{children}</div>
+  )
+  const DialogDescription = ({ children, className, ...rest }: any) => (
+    <div className={className} {...rest}>{children}</div>
+ ) + const TransitionChild = ({ children }: any) => ( + <>{typeof children === 'function' ? children({}) : children} + ) + const Transition = ({ show = true, children }: any) => ( + show ? <>{typeof children === 'function' ? children({}) : children} : null + ) + Transition.Child = TransitionChild + return { + Dialog: DialogComponent, + Transition, + DialogTitle, + Description: DialogDescription, + } +}) + +jest.mock('ahooks', () => { + const actual = jest.requireActual('ahooks') + return { + ...actual, + useDebounce: (value: unknown) => value, + } +}) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +const baseGroup = createGroup() +const baseMember = createMember() +const groupSubject: Subject = { + subjectId: baseGroup.id, + subjectType: SubjectType.GROUP, + groupData: baseGroup, +} as Subject +const memberSubject: Subject = { + subjectId: baseMember.id, + subjectType: SubjectType.ACCOUNT, + accountData: baseMember, +} as Subject + +const resetAccessControlStore = () => { + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) +} + +const resetGlobalStore = () => { + useGlobalPublicStore.setState({ + systemFeatures: defaultSystemFeatures, + isGlobalPending: false, + }) +} + +beforeAll(() => { + class MockIntersectionObserver { + observe = jest.fn(() => undefined) + disconnect = jest.fn(() => undefined) + unobserve = jest.fn(() => undefined) + } + // @ts-expect-error jsdom does not implement IntersectionObserver + globalThis.IntersectionObserver = MockIntersectionObserver +}) + +beforeEach(() => { + jest.clearAllMocks() + resetAccessControlStore() + resetGlobalStore() + mockMutateAsync.mockResolvedValue(undefined) + mockUseUpdateAccessMode.mockReturnValue({ + isPending: false, + mutateAsync: mockMutateAsync, + }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [baseGroup], + members: [baseMember], + }, + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [{ currPage: 1, subjects: [groupSubject, memberSubject], hasMore: false }] }, + }) +}) + +// AccessControlItem handles selected vs. 
unselected styling and click state updates +describe('AccessControlItem', () => { + it('should update current menu when selecting a different access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.PUBLIC }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + expect(option).toHaveClass('cursor-pointer') + + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) + + it('should keep current menu when clicking the selected access type', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + render( + + Organization Only + , + ) + + const option = screen.getByText('Organization Only').parentElement as HTMLElement + fireEvent.click(option) + + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION) + }) +}) + +// AccessControlDialog renders a headless UI dialog with a manual close control +describe('AccessControlDialog', () => { + it('should render dialog content when visible', () => { + render( + +
Dialog Content
+
, + ) + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByText('Dialog Content')).toBeInTheDocument() + }) + + it('should trigger onClose when clicking the close control', async () => { + const handleClose = jest.fn() + const { container } = render( + +
Dialog Content
+
, + ) + + const closeButton = container.querySelector('.absolute.right-5.top-5') as HTMLElement + fireEvent.click(closeButton) + + await waitFor(() => { + expect(handleClose).toHaveBeenCalledTimes(1) + }) + }) +}) + +// SpecificGroupsOrMembers syncs store state with fetched data and supports removals +describe('SpecificGroupsOrMembers', () => { + it('should render collapsed view when not in specific selection mode', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + + render() + + expect(screen.getByText('app.accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + it('should show loading state while pending', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: true, + data: undefined, + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + }) + + it('should render fetched groups and members and support removal', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + + render() + + await waitFor(() => { + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + const groupItem = screen.getByText(baseGroup.name).closest('div') + const groupRemove = groupItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(groupRemove) + + await waitFor(() => { + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + const memberItem = screen.getByText(baseMember.name).closest('div') + const memberRemove = memberItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(memberRemove) + + await waitFor(() => { + expect(screen.queryByText(baseMember.name)).not.toBeInTheDocument() + }) + }) +}) + +// AddMemberOrGroupDialog renders search results and updates store selections +describe('AddMemberOrGroupDialog', () => { + it('should open search popover and display candidates', async () => { + const user = userEvent.setup() + + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByPlaceholderText('app.accessControlDialog.operateGroupAndMember.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + it('should allow selecting members and expanding groups', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + const expandButton = screen.getByText('app.accessControlDialog.operateGroupAndMember.expand') + await user.click(expandButton) + expect(useAccessControlStore.getState().selectedGroupsForBreadcrumb).toEqual([baseGroup]) + + const memberLabel = screen.getByText(baseMember.name) + const memberCheckbox = memberLabel.parentElement?.previousElementSibling as HTMLElement + fireEvent.click(memberCheckbox) + + expect(useAccessControlStore.getState().specificMembers).toEqual([baseMember]) + }) + + it('should show empty state when no candidates are returned', async () => { + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [] }, + }) + + const user = 
userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByText('app.accessControlDialog.operateGroupAndMember.noResult')).toBeInTheDocument() + }) +}) + +// AccessControl integrates dialog, selection items, and confirm flow +describe('AccessControl', () => { + it('should initialize menu from app and call update on confirm', async () => { + const onClose = jest.fn() + const onConfirm = jest.fn() + const toastSpy = jest.spyOn(Toast, 'notify').mockReturnValue({}) + useAccessControlStore.setState({ + specificGroups: [baseGroup], + specificMembers: [baseMember], + }) + const app = { + id: 'app-id-1', + access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + } as App + + render( + , + ) + + await waitFor(() => { + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.SPECIFIC_GROUPS_MEMBERS) + }) + + fireEvent.click(screen.getByText('common.operation.confirm')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + appId: app.id, + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + subjects: [ + { subjectId: baseGroup.id, subjectType: SubjectType.GROUP }, + { subjectId: baseMember.id, subjectType: SubjectType.ACCOUNT }, + ], + }) + expect(toastSpy).toHaveBeenCalled() + expect(onConfirm).toHaveBeenCalled() + }) + }) + + it('should expose the external members tip when SSO is disabled', () => { + const app = { + id: 'app-id-2', + access_mode: AccessMode.PUBLIC, + } as App + + render( + , + ) + + expect(screen.getByText('app.accessControlDialog.accessItems.external')).toBeInTheDocument() + expect(screen.getByText('app.accessControlDialog.accessItems.anyone')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index e9519aeedf..17263fdd46 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -11,7 +11,7 @@ import Input from '../../base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import Loading from '../../base/loading' import useAccessControlStore from '../../../../context/access-control-store' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useSearchForWhiteListCandidates } from '@/service/access-control' import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control' import { SubjectType } from '@/models/access-control' @@ -32,7 +32,7 @@ export default function AddMemberOrGroupDialog() { const anchorRef = useRef(null) useEffect(() => { - const hasMore = data?.pages?.[0].hasMore ?? false + const hasMore = data?.pages?.[0]?.hasMore ?? false let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/app-publisher/suggested-action.tsx b/web/app/components/app/app-publisher/suggested-action.tsx index 2535de6654..154bacc361 100644 --- a/web/app/components/app/app-publisher/suggested-action.tsx +++ b/web/app/components/app/app-publisher/suggested-action.tsx @@ -1,6 +1,6 @@ import type { HTMLProps, PropsWithChildren } from 'react' import { RiArrowRightUpLine } from '@remixicon/react' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type SuggestedActionProps = PropsWithChildren & { icon?: React.ReactNode @@ -19,11 +19,9 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, . href={disabled ? undefined : link} target='_blank' rel='noreferrer' - className={classNames( - 'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', + className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent', - className, - )} + className)} onClick={handleClick} {...props} > diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx index ec5ab96d76..c9ebfefbe5 100644 --- a/web/app/components/app/configuration/base/feature-panel/index.tsx +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC, ReactNode } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IFeaturePanelProps = { className?: string diff --git a/web/app/components/app/configuration/base/operation-btn/index.tsx b/web/app/components/app/configuration/base/operation-btn/index.tsx index aba35cded2..db19d2976e 100644 --- a/web/app/components/app/configuration/base/operation-btn/index.tsx +++ b/web/app/components/app/configuration/base/operation-btn/index.tsx @@ -6,7 +6,7 @@ import { RiAddLine, RiEditLine, } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' export type IOperationBtnProps = { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 5bf2f177ff..6492864ce2 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -14,7 +14,7 @@ import s from './style.module.css' import MessageTypeSelector from './message-type-selector' import ConfirmAddVar from './confirm-add-var' import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { PromptRole, PromptVariable } from '@/models/debug' import { Copy, diff --git a/web/app/components/app/configuration/config-prompt/message-type-selector.tsx b/web/app/components/app/configuration/config-prompt/message-type-selector.tsx index 17b3ecb2f1..71f3e6ee5f 100644 --- a/web/app/components/app/configuration/config-prompt/message-type-selector.tsx +++ 
b/web/app/components/app/configuration/config-prompt/message-type-selector.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useBoolean, useClickAway } from 'ahooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PromptRole } from '@/models/debug' import { ChevronSelectorVertical } from '@/app/components/base/icons/src/vender/line/arrows' type Props = { diff --git a/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx index 9e10db93ae..90a19c883a 100644 --- a/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx +++ b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useEffect, useState } from 'react' import type { FC } from 'react' import { useDebounceFn } from 'ahooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { className?: string diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 68bf6dd7c2..e4c21b0cbc 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -7,7 +7,7 @@ import { produce } from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { PromptVariable } from '@/models/debug' import Tooltip from '@/app/components/base/tooltip' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/configuration/config-var/config-modal/field.tsx b/web/app/components/app/configuration/config-var/config-modal/field.tsx index b24e0be6ce..76d228358a 100644 --- a/web/app/components/app/configuration/config-var/config-modal/field.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/field.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useTranslation } from 'react-i18next' type Props = { diff --git a/web/app/components/app/configuration/config-var/config-modal/type-select.tsx b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx index 2b52991d4a..53d59eb24b 100644 --- a/web/app/components/app/configuration/config-var/config-modal/type-select.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import React, { useState } from 'react' import { ChevronDownIcon } from '@heroicons/react/20/solid' -import classNames from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -10,7 +9,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon' import type { InputVarType } from '@/app/components/workflow/types' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Badge from '@/app/components/base/badge' import { inputVarTypeToVarType } from 
'@/app/components/workflow/nodes/_base/components/variable/utils' @@ -47,7 +46,7 @@ const TypeSelector: FC = ({ > !readonly && setOpen(v => !v)} className='w-full'>
@@ -69,7 +68,7 @@ const TypeSelector: FC = ({
{items.map((item: Item) => (
{ + const actual = jest.requireActual('use-context-selector') + return { + ...actual, + useContext: (context: unknown) => mockUseContext(context), + } +}) + +const mockUseFeatures = jest.fn() +const mockUseFeaturesStore = jest.fn() +jest.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => any) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +const defaultFile: FileUpload = { + enabled: false, + allowed_file_types: [], + allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url], + number_limits: 3, + image: { + enabled: false, + detail: Resolution.low, + number_limits: 3, + transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }, +} + +let featureStoreState: FeatureStoreState +let setFeaturesMock: jest.Mock + +const setupFeatureStore = (fileOverrides: Partial = {}) => { + const mergedFile: FileUpload = { + ...defaultFile, + ...fileOverrides, + image: { + ...defaultFile.image, + ...fileOverrides.image, + }, + } + featureStoreState = { + features: { + file: mergedFile, + }, + setFeatures: jest.fn(), + showFeaturesModal: false, + setShowFeaturesModal: jest.fn(), + } + setFeaturesMock = featureStoreState.setFeatures as jest.Mock + mockUseFeaturesStore.mockReturnValue({ + getState: () => featureStoreState, + }) + mockUseFeatures.mockImplementation(selector => selector(featureStoreState)) +} + +const getLatestFileConfig = () => { + expect(setFeaturesMock).toHaveBeenCalled() + const latestFeatures = setFeaturesMock.mock.calls[setFeaturesMock.mock.calls.length - 1][0] as { file: FileUpload } + return latestFeatures.file +} + +beforeEach(() => { + jest.clearAllMocks() + mockUseContext.mockReturnValue({ + isShowVisionConfig: true, + isAllowVideoUpload: false, + }) + setupFeatureStore() +}) + +// ConfigVision handles toggling file upload types + visibility rules. 
+describe('ConfigVision', () => { + it('should not render when vision configuration is hidden', () => { + mockUseContext.mockReturnValue({ + isShowVisionConfig: false, + isAllowVideoUpload: false, + }) + + render() + + expect(screen.queryByText('appDebug.vision.name')).not.toBeInTheDocument() + }) + + it('should show the toggle and parameter controls when visible', () => { + render() + + expect(screen.getByText('appDebug.vision.name')).toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-checked', 'false') + }) + + it('should enable both image and video uploads when toggled on with video support', async () => { + const user = userEvent.setup() + mockUseContext.mockReturnValue({ + isShowVisionConfig: true, + isAllowVideoUpload: true, + }) + setupFeatureStore({ + allowed_file_types: [], + }) + + render() + await user.click(screen.getByRole('switch')) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.allowed_file_types).toEqual([SupportUploadFileTypes.image, SupportUploadFileTypes.video]) + expect(updatedFile.image?.enabled).toBe(true) + expect(updatedFile.enabled).toBe(true) + }) + + it('should disable image and video uploads when toggled off and no other types remain', async () => { + const user = userEvent.setup() + mockUseContext.mockReturnValue({ + isShowVisionConfig: true, + isAllowVideoUpload: true, + }) + setupFeatureStore({ + allowed_file_types: [SupportUploadFileTypes.image, SupportUploadFileTypes.video], + enabled: true, + image: { + enabled: true, + }, + }) + + render() + await user.click(screen.getByRole('switch')) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.allowed_file_types).toEqual([]) + expect(updatedFile.enabled).toBe(false) + expect(updatedFile.image?.enabled).toBe(false) + }) + + it('should keep file uploads enabled when other file types remain after disabling vision', async () => { + const user = userEvent.setup() + mockUseContext.mockReturnValue({ + isShowVisionConfig: true, + isAllowVideoUpload: false, + }) + setupFeatureStore({ + allowed_file_types: [SupportUploadFileTypes.image, SupportUploadFileTypes.document], + enabled: true, + image: { enabled: true }, + }) + + render() + await user.click(screen.getByRole('switch')) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.allowed_file_types).toEqual([SupportUploadFileTypes.document]) + expect(updatedFile.enabled).toBe(true) + expect(updatedFile.image?.enabled).toBe(false) + }) +}) + +// ParamConfig exposes ParamConfigContent via an inline trigger. +describe('ParamConfig', () => { + it('should toggle parameter panel when clicking the settings button', async () => { + setupFeatureStore() + const user = userEvent.setup() + + render() + + expect(screen.queryByText('appDebug.vision.visionSettings.title')).not.toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'appDebug.voice.settings' })) + + expect(await screen.findByText('appDebug.vision.visionSettings.title')).toBeInTheDocument() + }) +}) + +// ParamConfigContent manages resolution, upload source, and count limits. 
+describe('ParamConfigContent', () => { + it('should set resolution to high when the corresponding option is selected', async () => { + const user = userEvent.setup() + setupFeatureStore({ + image: { detail: Resolution.low }, + }) + + render() + + await user.click(screen.getByText('appDebug.vision.visionSettings.high')) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.image?.detail).toBe(Resolution.high) + }) + + it('should switch upload method to local only', async () => { + const user = userEvent.setup() + setupFeatureStore({ + allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }) + + render() + + await user.click(screen.getByText('appDebug.vision.visionSettings.localUpload')) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.allowed_file_upload_methods).toEqual([TransferMethod.local_file]) + expect(updatedFile.image?.transfer_methods).toEqual([TransferMethod.local_file]) + }) + + it('should update upload limit value when input changes', async () => { + setupFeatureStore({ + number_limits: 2, + }) + + render() + const input = screen.getByRole('spinbutton') as HTMLInputElement + fireEvent.change(input, { target: { value: '4' } }) + + const updatedFile = getLatestFileConfig() + expect(updatedFile.number_limits).toBe(4) + expect(updatedFile.image?.number_limits).toBe(4) + }) +}) diff --git a/web/app/components/app/configuration/config-vision/param-config.tsx b/web/app/components/app/configuration/config-vision/param-config.tsx index 5e4aac6a25..e6af188052 100644 --- a/web/app/components/app/configuration/config-vision/param-config.tsx +++ b/web/app/components/app/configuration/config-vision/param-config.tsx @@ -10,7 +10,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const ParamsConfig: FC = () => { const { t } = useTranslation() diff --git a/web/app/components/app/configuration/config/agent-setting-button.spec.tsx b/web/app/components/app/configuration/config/agent-setting-button.spec.tsx new file mode 100644 index 0000000000..db70865e51 --- /dev/null +++ b/web/app/components/app/configuration/config/agent-setting-button.spec.tsx @@ -0,0 +1,100 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AgentSettingButton from './agent-setting-button' +import type { AgentConfig } from '@/models/debug' +import { AgentStrategy } from '@/types/app' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +let latestAgentSettingProps: any +jest.mock('./agent/agent-setting', () => ({ + __esModule: true, + default: (props: any) => { + latestAgentSettingProps = props + return ( +
 <div data-testid="agent-setting">
+        <button onClick={() => props.onSave({ ...props.payload, max_iteration: 9 })}>save-agent</button>
+        <button onClick={props.onCancel}>cancel-agent</button>
+      </div>
+ ) + }, +})) + +const createAgentConfig = (overrides: Partial = {}): AgentConfig => ({ + enabled: true, + strategy: AgentStrategy.react, + max_iteration: 3, + tools: [], + ...overrides, +}) + +const setup = (overrides: Partial> = {}) => { + const props: React.ComponentProps = { + isFunctionCall: false, + isChatModel: true, + onAgentSettingChange: jest.fn(), + agentConfig: createAgentConfig(), + ...overrides, + } + + const user = userEvent.setup() + render() + return { props, user } +} + +beforeEach(() => { + jest.clearAllMocks() + latestAgentSettingProps = undefined +}) + +describe('AgentSettingButton', () => { + it('should render button label from translation key', () => { + setup() + + expect(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })).toBeInTheDocument() + }) + + it('should open AgentSetting with the provided configuration when clicked', async () => { + const { user, props } = setup({ isFunctionCall: true, isChatModel: false }) + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + + expect(screen.getByTestId('agent-setting')).toBeInTheDocument() + expect(latestAgentSettingProps.isFunctionCall).toBe(true) + expect(latestAgentSettingProps.isChatModel).toBe(false) + expect(latestAgentSettingProps.payload).toEqual(props.agentConfig) + }) + + it('should call onAgentSettingChange and close when AgentSetting saves', async () => { + const { user, props } = setup() + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + await user.click(screen.getByText('save-agent')) + + expect(props.onAgentSettingChange).toHaveBeenCalledTimes(1) + expect(props.onAgentSettingChange).toHaveBeenCalledWith({ + ...props.agentConfig, + max_iteration: 9, + }) + expect(screen.queryByTestId('agent-setting')).not.toBeInTheDocument() + }) + + it('should close AgentSetting without saving when cancel is triggered', async () => { + const { user, props } = setup() + + await user.click(screen.getByRole('button', { name: 'appDebug.agent.setting.name' })) + await user.click(screen.getByText('cancel-agent')) + + expect(props.onAgentSettingChange).not.toHaveBeenCalled() + expect(screen.queryByTestId('agent-setting')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx index 00c0776718..2ff1034537 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.spec.tsx @@ -4,12 +4,6 @@ import AgentSetting from './index' import { MAX_ITERATIONS_NUM } from '@/config' import type { AgentConfig } from '@/models/debug' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - jest.mock('ahooks', () => { const actual = jest.requireActual('ahooks') return { diff --git a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx index 6512e11545..6193392026 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' 
type Props = { className?: string diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 4793b5fe49..8dfa2f194b 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -25,7 +25,7 @@ import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ToolPicker from '@/app/components/workflow/block-selector/tool-picker' import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types' import { canFindTool } from '@/utils' diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index c5947495db..0627666b4c 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -22,7 +22,7 @@ import { CollectionType } from '@/app/components/tools/types' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools' import I18n from '@/context/i18n' import { getLanguage } from '@/i18n-config/language' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { ToolWithProvider } from '@/app/components/workflow/types' import { AuthCategory, diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index 71a9304d0c..78d7eef029 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -4,7 +4,7 @@ import React from 'react' import copy from 'copy-to-clipboard' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Copy, CopyCheck, diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx new file mode 100644 index 0000000000..cda24ea045 --- /dev/null +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.spec.tsx @@ -0,0 +1,862 @@ +import React from 'react' +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AssistantTypePicker from './index' +import type { AgentConfig } from '@/models/debug' +import { AgentStrategy } from '@/types/app' + +// Test utilities +const defaultAgentConfig: AgentConfig = { + enabled: true, + max_iteration: 3, + strategy: AgentStrategy.functionCall, + tools: [], +} + +const defaultProps = { + value: 'chat', + disabled: false, + onChange: jest.fn(), + isFunctionCall: true, + isChatModel: true, + agentConfig: defaultAgentConfig, + onAgentSettingChange: jest.fn(), +} + +const renderComponent = (props: Partial> = {}) => { + const mergedProps = { ...defaultProps, ...props } 
+ return render() +} + +// Helper to get option element by description (which is unique per option) +const getOptionByDescription = (descriptionRegex: RegExp) => { + const description = screen.getByText(descriptionRegex) + return description.parentElement as HTMLElement +} + +describe('AssistantTypePicker', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + renderComponent() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should render chat assistant by default when value is "chat"', () => { + // Arrange & Act + renderComponent({ value: 'chat' }) + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should render agent assistant when value is "agent"', () => { + // Arrange & Act + renderComponent({ value: 'agent' }) + + // Assert + expect(screen.getByText(/agentAssistant.name/i)).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should use provided value prop', () => { + // Arrange & Act + renderComponent({ value: 'agent' }) + + // Assert + expect(screen.getByText(/agentAssistant.name/i)).toBeInTheDocument() + }) + + it('should handle agentConfig prop', () => { + // Arrange + const customAgentConfig: AgentConfig = { + enabled: true, + max_iteration: 10, + strategy: AgentStrategy.react, + tools: [], + } + + // Act + expect(() => { + renderComponent({ agentConfig: customAgentConfig }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + it('should handle undefined agentConfig prop', () => { + // Arrange & Act + expect(() => { + renderComponent({ agentConfig: undefined }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should open dropdown when clicking trigger', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible + await waitFor(() => { + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + expect(chatOptions.length).toBeGreaterThan(1) + expect(agentOptions.length).toBeGreaterThan(0) + }) + }) + + it('should call onChange when selecting chat assistant', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + renderComponent({ value: 'agent', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open and find chat option + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Find and click the chat option by its unique description + const chatOption = getOptionByDescription(/chatAssistant.description/i) + await user.click(chatOption) + + // Assert + expect(onChange).toHaveBeenCalledWith('chat') + }) + + it('should call onChange when selecting agent assistant', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await 
user.click(trigger) + + // Wait for dropdown to open and click agent option + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert + expect(onChange).toHaveBeenCalledWith('agent') + }) + + it('should close dropdown when selecting chat assistant', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent' }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and select chat + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + const chatOption = getOptionByDescription(/chatAssistant.description/i) + await user.click(chatOption) + + // Assert - Dropdown should close (descriptions should not be visible) + await waitFor(() => { + expect(screen.queryByText(/chatAssistant.description/i)).not.toBeInTheDocument() + }) + }) + + it('should not close dropdown when selecting agent assistant', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat' }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and select agent + await waitFor(() => { + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + expect(agentOptions.length).toBeGreaterThan(0) + }) + + const agentOptions = screen.getAllByText(/agentAssistant.name/i) + await user.click(agentOptions[0]) + + // Assert - Dropdown should remain open (agent settings should be visible) + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + }) + + it('should not call onChange when clicking same value', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown and click same option + await waitFor(() => { + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + expect(chatOptions.length).toBeGreaterThan(1) + }) + + const chatOptions = screen.getAllByText(/chatAssistant.name/i) + await user.click(chatOptions[1]) + + // Assert + expect(onChange).not.toHaveBeenCalled() + }) + }) + + // Disabled state + describe('Disabled State', () => { + it('should not respond to clicks when disabled', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + renderComponent({ disabled: true, onChange }) + + // Act - Open dropdown (dropdown can still open when disabled) + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Act - Try to click an option + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert - onChange should not be called (options are disabled) + expect(onChange).not.toHaveBeenCalled() + }) + + it('should not show agent config UI when disabled', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: true }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) 
+ await user.click(trigger) + + // Assert - Agent settings option should not be visible + await waitFor(() => { + expect(screen.queryByText(/agent.setting.name/i)).not.toBeInTheDocument() + }) + }) + + it('should show agent config UI when not disabled', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Assert - Agent settings option should be visible + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + }) + }) + + // Agent Settings Modal + describe('Agent Settings Modal', () => { + it('should open agent settings modal when clicking agent config UI', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Click agent settings + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + }) + + it('should not open agent settings when value is not agent', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat', disabled: false }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Wait for dropdown to open + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Assert - Agent settings modal should not appear (value is 'chat') + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + + it('should call onAgentSettingChange when saving agent settings', async () => { + // Arrange + const user = userEvent.setup() + const onAgentSettingChange = jest.fn() + renderComponent({ value: 'agent', disabled: false, onAgentSettingChange }) + + // Act - Open dropdown and agent settings + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Wait for modal and click save + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + const saveButton = screen.getByText(/common.operation.save/i) + await user.click(saveButton) + + // Assert + expect(onAgentSettingChange).toHaveBeenCalledWith(defaultAgentConfig) + }) + + it('should close modal when saving agent settings', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown, agent settings, and save + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + await waitFor(() => { + expect(screen.getByText(/appDebug.agent.setting.name/i)).toBeInTheDocument() + }) + + const saveButton = 
screen.getByText(/common.operation.save/i) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + }) + + it('should close modal when canceling agent settings', async () => { + // Arrange + const user = userEvent.setup() + const onAgentSettingChange = jest.fn() + renderComponent({ value: 'agent', disabled: false, onAgentSettingChange }) + + // Act - Open dropdown, agent settings, and cancel + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + const cancelButton = screen.getByText(/common.operation.cancel/i) + await user.click(cancelButton) + + // Assert + await waitFor(() => { + expect(screen.queryByText(/common.operation.save/i)).not.toBeInTheDocument() + }) + expect(onAgentSettingChange).not.toHaveBeenCalled() + }) + + it('should close dropdown when opening agent settings', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', disabled: false }) + + // Act - Open dropdown and agent settings + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert - Modal should be open and dropdown should close + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + + // The dropdown should be closed (agent settings description should not be visible) + await waitFor(() => { + const descriptions = screen.queryAllByText(/agent.setting.description/i) + expect(descriptions.length).toBe(0) + }) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle rapid toggle clicks', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + await user.click(trigger) + await user.click(trigger) + + // Assert - Should not crash + expect(trigger).toBeInTheDocument() + }) + + it('should handle multiple rapid selection changes', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + renderComponent({ value: 'chat', onChange }) + + // Act - Open and select agent + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Click agent option - this stays open because value is 'agent' + const agentOption = getOptionByDescription(/agentAssistant.description/i) + await user.click(agentOption) + + // Assert - onChange should have been called once to switch to agent + await waitFor(() => { + expect(onChange).toHaveBeenCalledTimes(1) + }) + expect(onChange).toHaveBeenCalledWith('agent') + }) + + it('should handle missing callback functions gracefully', async () => { + // Arrange + const user = userEvent.setup() + + // Act & Assert - Should not crash + expect(() => { + renderComponent({ + onChange: undefined!, + 
onAgentSettingChange: undefined!, + }) + }).not.toThrow() + + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + }) + + it('should handle empty agentConfig', async () => { + // Arrange & Act + expect(() => { + renderComponent({ agentConfig: {} as AgentConfig }) + }).not.toThrow() + + // Assert + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + }) + + describe('should render with different prop combinations', () => { + const combinations = [ + { value: 'chat' as const, disabled: true, isFunctionCall: true, isChatModel: true }, + { value: 'agent' as const, disabled: false, isFunctionCall: false, isChatModel: false }, + { value: 'agent' as const, disabled: true, isFunctionCall: true, isChatModel: false }, + { value: 'chat' as const, disabled: false, isFunctionCall: false, isChatModel: true }, + ] + + it.each(combinations)( + 'value=$value, disabled=$disabled, isFunctionCall=$isFunctionCall, isChatModel=$isChatModel', + (combo) => { + // Arrange & Act + renderComponent(combo) + + // Assert + const expectedText = combo.value === 'agent' ? 'agentAssistant.name' : 'chatAssistant.name' + expect(screen.getByText(new RegExp(expectedText, 'i'))).toBeInTheDocument() + }, + ) + }) + }) + + // Accessibility + describe('Accessibility', () => { + it('should render interactive dropdown items', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible and clickable + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Verify we can interact with option elements using helper function + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + }) + }) + + // SelectItem Component + describe('SelectItem Component', () => { + it('should show checked state for selected option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'chat' }) + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Both options should be visible with radio components + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // The SelectItem components render with different visual states + // based on isChecked prop - we verify both options are rendered + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + }) + + it('should render description text', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Descriptions should be visible + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + }) + + it('should 
show Radio component for each option', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - Radio components should be present (both options visible) + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + }) + }) + + // Agent Setting Integration + describe('AgentSetting Integration', () => { + it('should show function call mode when isFunctionCall is true', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', isFunctionCall: true, isChatModel: false }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + expect(screen.getByText(/appDebug.agent.agentModeType.functionCall/i)).toBeInTheDocument() + }) + + it('should show built-in prompt when isFunctionCall is false', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent', isFunctionCall: false, isChatModel: true }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await waitFor(() => { + expect(screen.getByText(/common.operation.save/i)).toBeInTheDocument() + }) + expect(screen.getByText(/tools.builtInPromptTitle/i)).toBeInTheDocument() + }) + + it('should initialize max iteration from agentConfig payload', async () => { + // Arrange + const user = userEvent.setup() + const customConfig: AgentConfig = { + enabled: true, + max_iteration: 10, + strategy: AgentStrategy.react, + tools: [], + } + + renderComponent({ value: 'agent', agentConfig: customConfig }) + + // Act - Open dropdown and settings modal + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettingsTrigger = screen.getByText(/agent.setting.name/i) + await user.click(agentSettingsTrigger) + + // Assert + await screen.findByText(/common.operation.save/i) + const maxIterationInput = await screen.findByRole('spinbutton') + expect(maxIterationInput).toHaveValue(10) + }) + }) + + // Keyboard Navigation + describe('Keyboard Navigation', () => { + it('should support closing dropdown with Escape key', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Press Escape + await user.keyboard('{Escape}') + + // Assert - Dropdown should close + await waitFor(() => { + 
expect(screen.queryByText(/chatAssistant.description/i)).not.toBeInTheDocument() + }) + }) + + it('should allow keyboard focus on trigger element', () => { + // Arrange + renderComponent() + + // Act - Get trigger and verify it can receive focus + const trigger = screen.getByText(/chatAssistant.name/i) + + // Assert - Element should be focusable + expect(trigger).toBeInTheDocument() + expect(trigger.parentElement).toBeInTheDocument() + }) + + it('should allow keyboard focus on dropdown options', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + }) + + // Get options + const chatOption = getOptionByDescription(/chatAssistant.description/i) + const agentOption = getOptionByDescription(/agentAssistant.description/i) + + // Assert - Options should be focusable + expect(chatOption).toBeInTheDocument() + expect(agentOption).toBeInTheDocument() + + // Verify options can receive focus + act(() => { + chatOption.focus() + }) + expect(document.activeElement).toBe(chatOption) + }) + + it('should maintain keyboard accessibility for all interactive elements', async () => { + // Arrange + const user = userEvent.setup() + renderComponent({ value: 'agent' }) + + // Act - Open dropdown + const trigger = screen.getByText(/agentAssistant.name/i) + await user.click(trigger) + + // Assert - Agent settings button should be focusable + await waitFor(() => { + expect(screen.getByText(/agent.setting.name/i)).toBeInTheDocument() + }) + + const agentSettings = screen.getByText(/agent.setting.name/i) + expect(agentSettings).toBeInTheDocument() + }) + }) + + // ARIA Attributes + describe('ARIA Attributes', () => { + it('should have proper ARIA state for dropdown', async () => { + // Arrange + const user = userEvent.setup() + const { container } = renderComponent() + + // Act - Check initial state + const portalContainer = container.querySelector('[data-state]') + expect(portalContainer).toHaveAttribute('data-state', 'closed') + + // Open dropdown + const trigger = screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - State should change to open + await waitFor(() => { + const openPortal = container.querySelector('[data-state="open"]') + expect(openPortal).toBeInTheDocument() + }) + }) + + it('should have proper data-state attribute', () => { + // Arrange & Act + const { container } = renderComponent() + + // Assert - Portal should have data-state for accessibility + const portalContainer = container.querySelector('[data-state]') + expect(portalContainer).toBeInTheDocument() + expect(portalContainer).toHaveAttribute('data-state') + + // Should start in closed state + expect(portalContainer).toHaveAttribute('data-state', 'closed') + }) + + it('should maintain accessible structure for screen readers', () => { + // Arrange & Act + renderComponent({ value: 'chat' }) + + // Assert - Text content should be accessible + expect(screen.getByText(/chatAssistant.name/i)).toBeInTheDocument() + + // Icons should have proper structure + const { container } = renderComponent() + const icons = container.querySelectorAll('svg') + expect(icons.length).toBeGreaterThan(0) + }) + + it('should provide context through text labels', async () => { + // Arrange + const user = userEvent.setup() + renderComponent() + + // Act - Open dropdown + const trigger = 
screen.getByText(/chatAssistant.name/i) + await user.click(trigger) + + // Assert - All options should have descriptive text + await waitFor(() => { + expect(screen.getByText(/chatAssistant.description/i)).toBeInTheDocument() + expect(screen.getByText(/agentAssistant.description/i)).toBeInTheDocument() + }) + + // Title text should be visible + expect(screen.getByText(/assistantType.name/i)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx index 3597a6e292..50f16f957a 100644 --- a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx @@ -4,7 +4,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiArrowDownSLine } from '@remixicon/react' import AgentSetting from '../agent/agent-setting' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/components/app/configuration/config/automatic/idea-output.tsx b/web/app/components/app/configuration/config/automatic/idea-output.tsx index df4f76c92b..895f74baa3 100644 --- a/web/app/components/app/configuration/config/automatic/idea-output.tsx +++ b/web/app/components/app/configuration/config/automatic/idea-output.tsx @@ -3,7 +3,7 @@ import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid import { useBoolean } from 'ahooks' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Textarea from '@/app/components/base/textarea' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/automatic/instruction-editor.tsx b/web/app/components/app/configuration/config/automatic/instruction-editor.tsx index b14ee93313..409f335232 100644 --- a/web/app/components/app/configuration/config/automatic/instruction-editor.tsx +++ b/web/app/components/app/configuration/config/automatic/instruction-editor.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React from 'react' import PromptEditor from '@/app/components/base/prompt-editor' import type { GeneratorType } from './types' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types' import { BlockEnum } from '@/app/components/workflow/types' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/automatic/prompt-toast.tsx b/web/app/components/app/configuration/config/automatic/prompt-toast.tsx index 2826cc97c8..c9169f0ad7 100644 --- a/web/app/components/app/configuration/config/automatic/prompt-toast.tsx +++ b/web/app/components/app/configuration/config/automatic/prompt-toast.tsx @@ -1,7 +1,7 @@ import { RiArrowDownSLine, RiSparklingFill } from '@remixicon/react' import { useBoolean } from 'ahooks' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Markdown } from '@/app/components/base/markdown' import { useTranslation } from 'react-i18next' import s from './style.module.css' diff --git a/web/app/components/app/configuration/config/automatic/version-selector.tsx b/web/app/components/app/configuration/config/automatic/version-selector.tsx 
index c3d3e1d91c..715c1f3c80 100644 --- a/web/app/components/app/configuration/config/automatic/version-selector.tsx +++ b/web/app/components/app/configuration/config/automatic/version-selector.tsx @@ -1,7 +1,7 @@ import React, { useCallback } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' import { useBoolean } from 'ahooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' diff --git a/web/app/components/app/configuration/config/config-audio.spec.tsx b/web/app/components/app/configuration/config/config-audio.spec.tsx new file mode 100644 index 0000000000..94eeb87c99 --- /dev/null +++ b/web/app/components/app/configuration/config/config-audio.spec.tsx @@ -0,0 +1,123 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigAudio from './config-audio' +import type { FeatureStoreState } from '@/app/components/base/features/store' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' + +const mockUseContext = jest.fn() +jest.mock('use-context-selector', () => { + const actual = jest.requireActual('use-context-selector') + return { + ...actual, + useContext: (context: unknown) => mockUseContext(context), + } +}) + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockUseFeatures = jest.fn() +const mockUseFeaturesStore = jest.fn() +jest.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => any) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +type SetupOptions = { + isVisible?: boolean + allowedTypes?: SupportUploadFileTypes[] +} + +let mockFeatureStoreState: FeatureStoreState +let mockSetFeatures: jest.Mock +const mockStore = { + getState: jest.fn(() => mockFeatureStoreState), +} + +const setupFeatureStore = (allowedTypes: SupportUploadFileTypes[] = []) => { + mockSetFeatures = jest.fn() + mockFeatureStoreState = { + features: { + file: { + allowed_file_types: allowedTypes, + enabled: allowedTypes.length > 0, + }, + }, + setFeatures: mockSetFeatures, + showFeaturesModal: false, + setShowFeaturesModal: jest.fn(), + } + mockStore.getState.mockImplementation(() => mockFeatureStoreState) + mockUseFeaturesStore.mockReturnValue(mockStore) + mockUseFeatures.mockImplementation(selector => selector(mockFeatureStoreState)) +} + +const renderConfigAudio = (options: SetupOptions = {}) => { + const { + isVisible = true, + allowedTypes = [], + } = options + setupFeatureStore(allowedTypes) + mockUseContext.mockReturnValue({ + isShowAudioConfig: isVisible, + }) + const user = userEvent.setup() + render() + return { + user, + setFeatures: mockSetFeatures, + } +} + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ConfigAudio', () => { + it('should not render when the audio configuration is hidden', () => { + renderConfigAudio({ isVisible: false }) + + expect(screen.queryByText('appDebug.feature.audioUpload.title')).not.toBeInTheDocument() + }) + + it('should display the audio toggle state based on feature store data', () => { + renderConfigAudio({ allowedTypes: [SupportUploadFileTypes.audio] }) + + expect(screen.getByText('appDebug.feature.audioUpload.title')).toBeInTheDocument() + 
expect(screen.getByRole('switch')).toHaveAttribute('aria-checked', 'true') + }) + + it('should enable audio uploads when toggled on', async () => { + const { user, setFeatures } = renderConfigAudio() + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'false') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio], + enabled: true, + }), + })) + }) + + it('should disable audio uploads and turn off file feature when last type is removed', async () => { + const { user, setFeatures } = renderConfigAudio({ allowedTypes: [SupportUploadFileTypes.audio] }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'true') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [], + enabled: false, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/config/config-document.spec.tsx b/web/app/components/app/configuration/config/config-document.spec.tsx new file mode 100644 index 0000000000..aeb504fdbd --- /dev/null +++ b/web/app/components/app/configuration/config/config-document.spec.tsx @@ -0,0 +1,119 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigDocument from './config-document' +import type { FeatureStoreState } from '@/app/components/base/features/store' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' + +const mockUseContext = jest.fn() +jest.mock('use-context-selector', () => { + const actual = jest.requireActual('use-context-selector') + return { + ...actual, + useContext: (context: unknown) => mockUseContext(context), + } +}) + +const mockUseFeatures = jest.fn() +const mockUseFeaturesStore = jest.fn() +jest.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: (selector: (state: FeatureStoreState) => any) => mockUseFeatures(selector), + useFeaturesStore: () => mockUseFeaturesStore(), +})) + +type SetupOptions = { + isVisible?: boolean + allowedTypes?: SupportUploadFileTypes[] +} + +let mockFeatureStoreState: FeatureStoreState +let mockSetFeatures: jest.Mock +const mockStore = { + getState: jest.fn(() => mockFeatureStoreState), +} + +const setupFeatureStore = (allowedTypes: SupportUploadFileTypes[] = []) => { + mockSetFeatures = jest.fn() + mockFeatureStoreState = { + features: { + file: { + allowed_file_types: allowedTypes, + enabled: allowedTypes.length > 0, + }, + }, + setFeatures: mockSetFeatures, + showFeaturesModal: false, + setShowFeaturesModal: jest.fn(), + } + mockStore.getState.mockImplementation(() => mockFeatureStoreState) + mockUseFeaturesStore.mockReturnValue(mockStore) + mockUseFeatures.mockImplementation(selector => selector(mockFeatureStoreState)) +} + +const renderConfigDocument = (options: SetupOptions = {}) => { + const { + isVisible = true, + allowedTypes = [], + } = options + setupFeatureStore(allowedTypes) + mockUseContext.mockReturnValue({ + isShowDocumentConfig: isVisible, + }) + const user = userEvent.setup() + render() + return { + user, + setFeatures: mockSetFeatures, + } +} + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ConfigDocument', () => { + it('should not render when the document configuration is hidden', () => { + renderConfigDocument({ isVisible: false }) + + 
expect(screen.queryByText('appDebug.feature.documentUpload.title')).not.toBeInTheDocument() + }) + + it('should show document toggle badge when configuration is visible', () => { + renderConfigDocument({ allowedTypes: [SupportUploadFileTypes.document] }) + + expect(screen.getByText('appDebug.feature.documentUpload.title')).toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-checked', 'true') + }) + + it('should add document type to allowed list when toggled on', async () => { + const { user, setFeatures } = renderConfigDocument({ allowedTypes: [SupportUploadFileTypes.audio] }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'false') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio, SupportUploadFileTypes.document], + enabled: true, + }), + })) + }) + + it('should remove document type but keep file feature enabled when other types remain', async () => { + const { user, setFeatures } = renderConfigDocument({ + allowedTypes: [SupportUploadFileTypes.document, SupportUploadFileTypes.audio], + }) + const toggle = screen.getByRole('switch') + + expect(toggle).toHaveAttribute('aria-checked', 'true') + await user.click(toggle) + + expect(setFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ + allowed_file_types: [SupportUploadFileTypes.audio], + enabled: true, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/config/index.spec.tsx b/web/app/components/app/configuration/config/index.spec.tsx new file mode 100644 index 0000000000..814c52c3d7 --- /dev/null +++ b/web/app/components/app/configuration/config/index.spec.tsx @@ -0,0 +1,254 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import Config from './index' +import type { ModelConfig, PromptVariable } from '@/models/debug' +import * as useContextSelector from 'use-context-selector' +import type { ToolItem } from '@/types/app' +import { AgentStrategy, AppModeEnum, ModelModeType } from '@/types/app' + +jest.mock('use-context-selector', () => { + const actual = jest.requireActual('use-context-selector') + return { + ...actual, + useContext: jest.fn(), + } +}) + +const mockFormattingDispatcher = jest.fn() +jest.mock('../debug/hooks', () => ({ + __esModule: true, + useFormattingChangedDispatcher: () => mockFormattingDispatcher, +})) + +let latestConfigPromptProps: any +jest.mock('@/app/components/app/configuration/config-prompt', () => ({ + __esModule: true, + default: (props: any) => { + latestConfigPromptProps = props + return
+    <div data-testid="config-prompt" />
+  },
+}))
+
+let latestConfigVarProps: any
+jest.mock('@/app/components/app/configuration/config-var', () => ({
+  __esModule: true,
+  default: (props: any) => {
+    latestConfigVarProps = props
+    return <div data-testid="config-var" />
+  },
+}))
+
+jest.mock('../dataset-config', () => ({
+  __esModule: true,
+  default: () => <div data-testid="dataset-config" />,
+}))
+
+jest.mock('./agent/agent-tools', () => ({
+  __esModule: true,
+  default: () => <div data-testid="agent-tools" />,
+}))
+
+jest.mock('../config-vision', () => ({
+  __esModule: true,
+  default: () => <div data-testid="config-vision" />,
+}))
+
+jest.mock('./config-document', () => ({
+  __esModule: true,
+  default: () => <div data-testid="config-document" />,
+}))
+
+jest.mock('./config-audio', () => ({
+  __esModule: true,
+  default: () => <div data-testid="config-audio" />,
+}))
+
+let latestHistoryPanelProps: any
+jest.mock('../config-prompt/conversation-history/history-panel', () => ({
+  __esModule: true,
+  default: (props: any) => {
+    latestHistoryPanelProps = props
+    return <div data-testid="history-panel" />
+ }, +})) + +type MockContext = { + mode: AppModeEnum + isAdvancedMode: boolean + modelModeType: ModelModeType + isAgent: boolean + hasSetBlockStatus: { + context: boolean + history: boolean + query: boolean + } + showHistoryModal: jest.Mock + modelConfig: ModelConfig + setModelConfig: jest.Mock + setPrevPromptConfig: jest.Mock +} + +const createPromptVariable = (overrides: Partial = {}): PromptVariable => ({ + key: 'variable', + name: 'Variable', + type: 'string', + ...overrides, +}) + +const createModelConfig = (overrides: Partial = {}): ModelConfig => ({ + provider: 'openai', + model_id: 'gpt-4', + mode: ModelModeType.chat, + configs: { + prompt_template: 'Hello {{variable}}', + prompt_variables: [createPromptVariable({ key: 'existing' })], + }, + chat_prompt_config: null, + completion_prompt_config: null, + opening_statement: null, + more_like_this: null, + suggested_questions: null, + suggested_questions_after_answer: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + retriever_resource: null, + sensitive_word_avoidance: null, + annotation_reply: null, + external_data_tools: null, + system_parameters: { + audio_file_size_limit: 1, + file_size_limit: 1, + image_file_size_limit: 1, + video_file_size_limit: 1, + workflow_file_upload_limit: 1, + }, + dataSets: [], + agentConfig: { + enabled: false, + strategy: AgentStrategy.react, + max_iteration: 1, + tools: [] as ToolItem[], + }, + ...overrides, +}) + +const createContextValue = (overrides: Partial = {}): MockContext => ({ + mode: AppModeEnum.CHAT, + isAdvancedMode: false, + modelModeType: ModelModeType.chat, + isAgent: false, + hasSetBlockStatus: { + context: false, + history: true, + query: false, + }, + showHistoryModal: jest.fn(), + modelConfig: createModelConfig(), + setModelConfig: jest.fn(), + setPrevPromptConfig: jest.fn(), + ...overrides, +}) + +const mockUseContext = useContextSelector.useContext as jest.Mock + +const renderConfig = (contextOverrides: Partial = {}) => { + const contextValue = createContextValue(contextOverrides) + mockUseContext.mockReturnValue(contextValue) + return { + contextValue, + ...render(), + } +} + +beforeEach(() => { + jest.clearAllMocks() + latestConfigPromptProps = undefined + latestConfigVarProps = undefined + latestHistoryPanelProps = undefined +}) + +// Rendering scenarios ensure the layout toggles agent/history specific sections correctly. 
+describe('Config - Rendering', () => { + it('should render baseline sections without agent specific panels', () => { + renderConfig() + + expect(screen.getByTestId('config-prompt')).toBeInTheDocument() + expect(screen.getByTestId('config-var')).toBeInTheDocument() + expect(screen.getByTestId('dataset-config')).toBeInTheDocument() + expect(screen.getByTestId('config-vision')).toBeInTheDocument() + expect(screen.getByTestId('config-document')).toBeInTheDocument() + expect(screen.getByTestId('config-audio')).toBeInTheDocument() + expect(screen.queryByTestId('agent-tools')).not.toBeInTheDocument() + expect(screen.queryByTestId('history-panel')).not.toBeInTheDocument() + }) + + it('should show AgentTools when app runs in agent mode', () => { + renderConfig({ isAgent: true }) + + expect(screen.getByTestId('agent-tools')).toBeInTheDocument() + }) + + it('should display HistoryPanel only when advanced chat completion values apply', () => { + const showHistoryModal = jest.fn() + renderConfig({ + isAdvancedMode: true, + mode: AppModeEnum.ADVANCED_CHAT, + modelModeType: ModelModeType.completion, + hasSetBlockStatus: { + context: false, + history: false, + query: false, + }, + showHistoryModal, + }) + + expect(screen.getByTestId('history-panel')).toBeInTheDocument() + expect(latestHistoryPanelProps.showWarning).toBe(true) + expect(latestHistoryPanelProps.onShowEditModal).toBe(showHistoryModal) + }) +}) + +// Prompt handling scenarios validate integration between Config and prompt children. +describe('Config - Prompt Handling', () => { + it('should update prompt template and dispatch formatting event when text changes', () => { + const { contextValue } = renderConfig() + const previousVariables = contextValue.modelConfig.configs.prompt_variables + const additions = [createPromptVariable({ key: 'new', name: 'New' })] + + latestConfigPromptProps.onChange('Updated template', additions) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(contextValue.setModelConfig).toHaveBeenCalledWith(expect.objectContaining({ + configs: expect.objectContaining({ + prompt_template: 'Updated template', + prompt_variables: [...previousVariables, ...additions], + }), + })) + expect(mockFormattingDispatcher).toHaveBeenCalledTimes(1) + }) + + it('should skip formatting dispatcher when template remains identical', () => { + const { contextValue } = renderConfig() + const unchangedTemplate = contextValue.modelConfig.configs.prompt_template + + latestConfigPromptProps.onChange(unchangedTemplate, [createPromptVariable({ key: 'added' })]) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(mockFormattingDispatcher).not.toHaveBeenCalled() + }) + + it('should replace prompt variables when ConfigVar reports updates', () => { + const { contextValue } = renderConfig() + const replacementVariables = [createPromptVariable({ key: 'replacement' })] + + latestConfigVarProps.onPromptVariablesChange(replacementVariables) + + expect(contextValue.setPrevPromptConfig).toHaveBeenCalledWith(contextValue.modelConfig.configs) + expect(contextValue.setModelConfig).toHaveBeenCalledWith(expect.objectContaining({ + configs: expect.objectContaining({ + prompt_variables: replacementVariables, + }), + })) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 85d46122a3..7fd7011a56 100644 --- 
a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -13,7 +13,7 @@ import Drawer from '@/app/components/base/drawer' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' type ItemProps = { diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.tsx index ebba9c51cb..80cc50acdf 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.tsx @@ -4,7 +4,7 @@ import React from 'react' import { useTranslation } from 'react-i18next' import type { Props } from './var-picker' import VarPicker from './var-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { BracketsX } from '@/app/components/base/icons/src/vender/line/development' import Tooltip from '@/app/components/base/tooltip' diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx index c443ea0b1f..f5ea2eaa27 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { ChevronDownIcon } from '@heroicons/react/24/outline' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/components/app/configuration/dataset-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/index.spec.tsx new file mode 100644 index 0000000000..3c48eca206 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/index.spec.tsx @@ -0,0 +1,1048 @@ +import { render, screen, within } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetConfig from './index' +import type { DataSet } from '@/models/datasets' +import { DataSourceType, DatasetPermission } from '@/models/datasets' +import { AppModeEnum } from '@/types/app' +import { ModelModeType } from '@/types/app' +import { RETRIEVE_TYPE } from '@/types/app' +import { ComparisonOperator, LogicalOperator } from '@/app/components/workflow/nodes/knowledge-retrieval/types' +import type { DatasetConfigs } from '@/models/debug' + +// Mock external dependencies +jest.mock('@/app/components/workflow/nodes/knowledge-retrieval/utils', () => ({ + getMultipleRetrievalConfig: jest.fn(() => ({ + top_k: 4, + score_threshold: 0.7, + reranking_enable: false, + reranking_model: undefined, + reranking_mode: 'reranking_model', + weights: { weight1: 1.0 }, + })), + getSelectedDatasetsMode: jest.fn(() => ({ + allInternal: true, + allExternal: false, + mixtureInternalAndExternal: false, + mixtureHighQualityAndEconomic: false, + inconsistentEmbeddingModel: false, + })), +})) + +jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + 
useModelListAndDefaultModelAndCurrentProviderAndModel: jest.fn(() => ({ + currentModel: { model: 'rerank-model' }, + currentProvider: { provider: 'openai' }, + })), +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: jest.fn((fn: any) => fn({ + userProfile: { + id: 'user-123', + }, + })), +})) + +jest.mock('@/utils/permission', () => ({ + hasEditPermissionForDataset: jest.fn(() => true), +})) + +jest.mock('../debug/hooks', () => ({ + useFormattingChangedDispatcher: jest.fn(() => jest.fn()), +})) + +jest.mock('lodash-es', () => ({ + intersectionBy: jest.fn((...arrays) => { + // Mock realistic intersection behavior based on metadata name + const validArrays = arrays.filter(Array.isArray) + if (validArrays.length === 0) return [] + + // Start with first array and filter down + return validArrays[0].filter((item: any) => { + if (!item || !item.name) return false + + // Only return items that exist in all arrays + return validArrays.every(array => + array.some((otherItem: any) => + otherItem && otherItem.name === item.name, + ), + ) + }) + }), +})) + +jest.mock('uuid', () => ({ + v4: jest.fn(() => 'mock-uuid'), +})) + +// Mock child components +jest.mock('./card-item', () => ({ + __esModule: true, + default: ({ config, onRemove, onSave, editable }: any) => ( +
+    <div data-testid={`card-item-${config.id}`}>
+      <span>{config.name}</span>
+      {editable && <button onClick={() => onSave(config)}>Edit</button>}
+      <button onClick={() => onRemove()}>Remove</button>
+    </div>
+ ), +})) + +jest.mock('./params-config', () => ({ + __esModule: true, + default: ({ disabled, selectedDatasets }: any) => ( + + ), +})) + +jest.mock('./context-var', () => ({ + __esModule: true, + default: ({ value, options, onChange }: any) => ( + + ), +})) + +jest.mock('@/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter', () => ({ + __esModule: true, + default: ({ + metadataList, + metadataFilterMode, + handleMetadataFilterModeChange, + handleAddCondition, + handleRemoveCondition, + handleUpdateCondition, + handleToggleConditionLogicalOperator, + }: any) => ( +
+ {metadataList.length} + + + + + +
+ ), +})) + +// Mock context +const mockConfigContext: any = { + mode: AppModeEnum.CHAT, + modelModeType: ModelModeType.chat, + isAgent: false, + dataSets: [], + setDataSets: jest.fn(), + modelConfig: { + configs: { + prompt_variables: [], + }, + }, + setModelConfig: jest.fn(), + showSelectDataSet: jest.fn(), + datasetConfigs: { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + metadata_filtering_mode: 'disabled' as any, + metadata_filtering_conditions: undefined, + datasets: { + datasets: [], + }, + } as DatasetConfigs, + datasetConfigsRef: { + current: { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + metadata_filtering_mode: 'disabled' as any, + metadata_filtering_conditions: undefined, + datasets: { + datasets: [], + }, + } as DatasetConfigs, + }, + setDatasetConfigs: jest.fn(), + setRerankSettingModalOpen: jest.fn(), +} + +jest.mock('@/context/debug-configuration', () => ({ + __esModule: true, + default: ({ children }: any) => ( +
+    <div>
+      {children}
+    </div>
+ ), +})) + +jest.mock('use-context-selector', () => ({ + useContext: jest.fn(() => mockConfigContext), +})) + +const createMockDataset = (overrides: Partial = {}): DataSet => { + const defaultDataset: DataSet = { + id: 'dataset-1', + name: 'Test Dataset', + indexing_status: 'completed' as any, + icon_info: { + icon: '📘', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + }, + description: 'Test dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as any, + author_name: 'Test Author', + created_by: 'user-123', + updated_by: 'user-123', + updated_at: Date.now(), + app_count: 0, + doc_form: 'text' as any, + document_count: 10, + total_document_count: 10, + total_available_documents: 10, + word_count: 1000, + provider: 'dify', + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: 'semantic_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + }, + retrieval_model: { + search_method: 'semantic_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: true, + }, + built_in_field_enabled: true, + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + keyword_number: 3, + pipeline_id: 'pipeline-123', + is_published: true, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + } + return defaultDataset +} + +const renderDatasetConfig = (contextOverrides: Partial = {}) => { + const useContextSelector = require('use-context-selector').useContext + const mergedContext = { ...mockConfigContext, ...contextOverrides } + useContextSelector.mockReturnValue(mergedContext) + + return render() +} + +describe('DatasetConfig', () => { + beforeEach(() => { + jest.clearAllMocks() + mockConfigContext.dataSets = [] + mockConfigContext.setDataSets = jest.fn() + mockConfigContext.setModelConfig = jest.fn() + mockConfigContext.setDatasetConfigs = jest.fn() + mockConfigContext.setRerankSettingModalOpen = jest.fn() + }) + + describe('Rendering', () => { + it('should render dataset configuration panel when component mounts', () => { + renderDatasetConfig() + + expect(screen.getByText('appDebug.feature.dataSet.title')).toBeInTheDocument() + }) + + it('should display empty state message when no datasets are configured', () => { + renderDatasetConfig() + + expect(screen.getByText(/no.*data/i)).toBeInTheDocument() + expect(screen.getByTestId('params-config')).toBeDisabled() + }) + + it('should render dataset cards and enable parameters when datasets exist', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + expect(screen.getByText(dataset.name)).toBeInTheDocument() + expect(screen.getByTestId('params-config')).not.toBeDisabled() + }) + + 
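// Illustrative sketch, not part of the change set above: when several mocked dataset
// cards render at once, RTL's `within` scopes queries to a single card so identical
// text in a sibling card cannot produce a false match. It reuses only helpers already
// defined in this spec (createMockDataset, renderDatasetConfig) and the mocked
// card-item testid format `card-item-${id}`.
it('scopes queries to a single dataset card (illustrative sketch)', () => {
  renderDatasetConfig({
    dataSets: [
      createMockDataset({ id: 'ds1', name: 'Dataset 1' }),
      createMockDataset({ id: 'ds2', name: 'Dataset 2' }),
    ],
  })

  // Narrow every query to the first card before asserting on its contents.
  const firstCard = within(screen.getByTestId('card-item-ds1'))
  expect(firstCard.getByText('Dataset 1')).toBeInTheDocument()
  expect(firstCard.queryByText('Dataset 2')).not.toBeInTheDocument()
})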
it('should show configuration title and add dataset button in header', () => { + renderDatasetConfig() + + expect(screen.getByText('appDebug.feature.dataSet.title')).toBeInTheDocument() + expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should hide parameters configuration when in agent mode', () => { + renderDatasetConfig({ + isAgent: true, + }) + + expect(screen.queryByTestId('params-config')).not.toBeInTheDocument() + }) + }) + + describe('Dataset Management', () => { + it('should open dataset selection modal when add button is clicked', async () => { + const user = userEvent.setup() + renderDatasetConfig() + + const addButton = screen.getByText('common.operation.add') + await user.click(addButton) + + expect(mockConfigContext.showSelectDataSet).toHaveBeenCalledTimes(1) + }) + + it('should remove dataset and update configuration when remove button is clicked', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + const removeButton = screen.getByText('Remove') + await user.click(removeButton) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith([]) + // Note: setDatasetConfigs is also called but its exact parameters depend on + // the retrieval config calculation which involves complex mocked utilities + }) + + it('should trigger rerank setting modal when removing dataset requires rerank configuration', async () => { + const user = userEvent.setup() + const { getSelectedDatasetsMode } = require('@/app/components/workflow/nodes/knowledge-retrieval/utils') + + // Mock scenario that triggers rerank modal + getSelectedDatasetsMode.mockReturnValue({ + allInternal: false, + allExternal: true, + mixtureInternalAndExternal: false, + mixtureHighQualityAndEconomic: false, + inconsistentEmbeddingModel: false, + }) + + const dataset = createMockDataset() + renderDatasetConfig({ + dataSets: [dataset], + }) + + const removeButton = screen.getByText('Remove') + await user.click(removeButton) + + expect(mockConfigContext.setRerankSettingModalOpen).toHaveBeenCalledWith(true) + }) + + it('should handle dataset save', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + }) + + // Mock the onSave in card-item component - it will pass the original dataset + const editButton = screen.getByText('Edit') + await user.click(editButton) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + id: dataset.id, + name: dataset.name, + editable: true, + }), + ]), + ) + }) + + it('should format datasets with edit permission', () => { + const dataset = createMockDataset({ + created_by: 'user-123', + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + }) + + describe('Context Variables', () => { + it('should show context variable selector in completion mode with datasets', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + { key: 'context', name: 'Context', type: 'string', is_context_var: true }, + ], + }, + }, + }) + + expect(screen.getByTestId('context-var')).toBeInTheDocument() + // Should find the selected context variable in the options + 
expect(screen.getByText('Select context variable')).toBeInTheDocument() + }) + + it('should not show context variable selector in chat mode', () => { + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.CHAT, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + ], + }, + }, + }) + + expect(screen.queryByTestId('context-var')).not.toBeInTheDocument() + }) + + it('should handle context variable selection', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [dataset], + modelConfig: { + configs: { + prompt_variables: [ + { key: 'query', name: 'Query', type: 'string', is_context_var: false }, + { key: 'context', name: 'Context', type: 'string', is_context_var: true }, + ], + }, + }, + }) + + const select = screen.getByTestId('context-var') + await user.selectOptions(select, 'query') + + expect(mockConfigContext.setModelConfig).toHaveBeenCalled() + }) + }) + + describe('Metadata Filtering', () => { + it('should render metadata filter component', () => { + const dataset = createMockDataset({ + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('2') // both 'category' and 'priority' + }) + + it('should handle metadata filter mode change', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + const updatedDatasetConfigs = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_mode: 'disabled' as any, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: updatedDatasetConfigs, + }) + + // Update the ref to match + mockConfigContext.datasetConfigsRef.current = updatedDatasetConfigs + + const select = within(screen.getByTestId('metadata-filter')).getByDisplayValue('Disabled') + await user.selectOptions(select, 'automatic') + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_mode: 'automatic', + }), + ) + }) + + it('should handle adding metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + const baseDatasetConfigs = { + ...mockConfigContext.datasetConfigs, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: baseDatasetConfigs, + }) + + // Update the ref to match + mockConfigContext.datasetConfigsRef.current = baseDatasetConfigs + + const addButton = within(screen.getByTestId('metadata-filter')).getByText('Add Condition') + await user.click(addButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + logical_operator: LogicalOperator.and, + conditions: expect.arrayContaining([ + expect.objectContaining({ + id: 'mock-uuid', + name: 'test', + comparison_operator: ComparisonOperator.is, + }), + ]), + }), + }), + ) + }) + + it('should handle removing metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + 
conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + // Update ref to match datasetConfigs + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const removeButton = within(screen.getByTestId('metadata-filter')).getByText('Remove Condition') + await user.click(removeButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + conditions: [], + }), + }), + ) + }) + + it('should handle updating metadata conditions', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const updateButton = within(screen.getByTestId('metadata-filter')).getByText('Update Condition') + await user.click(updateButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + conditions: expect.arrayContaining([ + expect.objectContaining({ + name: 'updated', + }), + ]), + }), + }), + ) + }) + + it('should handle toggling logical operator', async () => { + const user = userEvent.setup() + const dataset = createMockDataset() + + const datasetConfigsWithConditions = { + ...mockConfigContext.datasetConfigs, + metadata_filtering_conditions: { + logical_operator: LogicalOperator.and, + conditions: [ + { id: 'condition-id', name: 'test', comparison_operator: ComparisonOperator.is }, + ], + }, + } + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: datasetConfigsWithConditions, + }) + + mockConfigContext.datasetConfigsRef.current = datasetConfigsWithConditions + + const toggleButton = within(screen.getByTestId('metadata-filter')).getByText('Toggle Operator') + await user.click(toggleButton) + + expect(mockConfigContext.setDatasetConfigs).toHaveBeenCalledWith( + expect.objectContaining({ + metadata_filtering_conditions: expect.objectContaining({ + logical_operator: LogicalOperator.or, + }), + }), + ) + }) + }) + + describe('Edge Cases', () => { + it('should handle null doc_metadata gracefully', () => { + const dataset = createMockDataset({ + doc_metadata: undefined, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('0') + }) + + it('should handle empty doc_metadata array', () => { + const dataset = createMockDataset({ + doc_metadata: [], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('0') + }) + + it('should handle missing userProfile', () => { + const useSelector = require('@/context/app-context').useSelector + useSelector.mockImplementation((fn: any) => fn({ userProfile: null })) + + const dataset = createMockDataset() + + 
renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should handle missing datasetConfigsRef gracefully', () => { + const dataset = createMockDataset() + + // Test with undefined datasetConfigsRef - component renders without immediate error + // The component will fail on interaction due to non-null assertions in handlers + expect(() => { + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigsRef: undefined as any, + }) + }).not.toThrow() + + // The component currently expects datasetConfigsRef to exist for interactions + // This test documents the current behavior and requirements + }) + + it('should handle missing prompt_variables', () => { + // Context var is only shown when datasets exist AND there are prompt_variables + // Test with no datasets to ensure context var is not shown + renderDatasetConfig({ + mode: AppModeEnum.COMPLETION, + dataSets: [], + modelConfig: { + configs: { + prompt_variables: [], + }, + }, + }) + + expect(screen.queryByTestId('context-var')).not.toBeInTheDocument() + }) + }) + + describe('Component Integration', () => { + it('should integrate with card item component', () => { + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + expect(screen.getByTestId('card-item-ds1')).toBeInTheDocument() + expect(screen.getByTestId('card-item-ds2')).toBeInTheDocument() + expect(screen.getByText('Dataset 1')).toBeInTheDocument() + expect(screen.getByText('Dataset 2')).toBeInTheDocument() + }) + + it('should integrate with params config component', () => { + const datasets = [ + createMockDataset(), + createMockDataset({ id: 'ds2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + const paramsConfig = screen.getByTestId('params-config') + expect(paramsConfig).toBeInTheDocument() + expect(paramsConfig).toHaveTextContent('Params (2)') + expect(paramsConfig).not.toBeDisabled() + }) + + it('should integrate with metadata filter component', () => { + const datasets = [ + createMockDataset({ + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'tags', type: 'string' } as any, + ], + }), + createMockDataset({ + id: 'ds2', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + const metadataFilter = screen.getByTestId('metadata-filter') + expect(metadataFilter).toBeInTheDocument() + // Should show intersection (only 'category') + expect(screen.getByTestId('metadata-list-count')).toHaveTextContent('1') + }) + }) + + describe('Model Configuration', () => { + it('should handle metadata model change', () => { + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: { + ...mockConfigContext.datasetConfigs, + metadata_model_config: { + provider: 'openai', + name: 'gpt-3.5-turbo', + mode: AppModeEnum.CHAT, + completion_params: { temperature: 0.7 }, + }, + }, + }) + + // The component would need to expose this functionality through the metadata filter + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + + it('should handle metadata completion params change', () => { + const dataset = createMockDataset() + + renderDatasetConfig({ + dataSets: [dataset], + datasetConfigs: { + ...mockConfigContext.datasetConfigs, + 
metadata_model_config: { + provider: 'openai', + name: 'gpt-3.5-turbo', + mode: AppModeEnum.CHAT, + completion_params: { temperature: 0.5, max_tokens: 100 }, + }, + }, + }) + + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + }) + + describe('Permission Handling', () => { + it('should hide edit options when user lacks permission', () => { + const { hasEditPermissionForDataset } = require('@/utils/permission') + hasEditPermissionForDataset.mockReturnValue(false) + + const dataset = createMockDataset({ + created_by: 'other-user', + permission: DatasetPermission.onlyMe, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + // The editable property should be false when no permission + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should show readonly state for non-editable datasets', () => { + const { hasEditPermissionForDataset } = require('@/utils/permission') + hasEditPermissionForDataset.mockReturnValue(false) + + const dataset = createMockDataset({ + created_by: 'admin', + permission: DatasetPermission.allTeamMembers, + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + + it('should allow editing when user has partial member permission', () => { + const { hasEditPermissionForDataset } = require('@/utils/permission') + hasEditPermissionForDataset.mockReturnValue(true) + + const dataset = createMockDataset({ + created_by: 'admin', + permission: DatasetPermission.partialMembers, + partial_member_list: ['user-123'], + }) + + renderDatasetConfig({ + dataSets: [dataset], + }) + + expect(screen.getByTestId(`card-item-${dataset.id}`)).toBeInTheDocument() + }) + }) + + describe('Dataset Reordering and Management', () => { + it('should maintain dataset order after updates', () => { + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + createMockDataset({ id: 'ds3', name: 'Dataset 3' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Verify order is maintained + expect(screen.getByText('Dataset 1')).toBeInTheDocument() + expect(screen.getByText('Dataset 2')).toBeInTheDocument() + expect(screen.getByText('Dataset 3')).toBeInTheDocument() + }) + + it('should handle multiple dataset operations correctly', async () => { + const user = userEvent.setup() + const datasets = [ + createMockDataset({ id: 'ds1', name: 'Dataset 1' }), + createMockDataset({ id: 'ds2', name: 'Dataset 2' }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Remove first dataset + const removeButton1 = screen.getAllByText('Remove')[0] + await user.click(removeButton1) + + expect(mockConfigContext.setDataSets).toHaveBeenCalledWith([datasets[1]]) + }) + }) + + describe('Complex Configuration Scenarios', () => { + it('should handle multiple retrieval methods in configuration', () => { + const datasets = [ + createMockDataset({ + id: 'ds1', + retrieval_model: { + search_method: 'semantic_search' as any, + reranking_enable: true, + reranking_model: { + reranking_provider_name: 'cohere', + reranking_model_name: 'rerank-v3.5', + }, + top_k: 5, + score_threshold_enabled: true, + score_threshold: 0.8, + }, + }), + createMockDataset({ + id: 'ds2', + retrieval_model: { + search_method: 'full_text_search' as any, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + 
score_threshold: 0.5, + }, + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + expect(screen.getByTestId('params-config')).toHaveTextContent('Params (2)') + }) + + it('should handle external knowledge base integration', () => { + const externalDataset = createMockDataset({ + provider: 'notion', + external_knowledge_info: { + external_knowledge_id: 'notion-123', + external_knowledge_api_id: 'api-456', + external_knowledge_api_name: 'Notion Integration', + external_knowledge_api_endpoint: 'https://api.notion.com', + }, + }) + + renderDatasetConfig({ + dataSets: [externalDataset], + }) + + expect(screen.getByTestId(`card-item-${externalDataset.id}`)).toBeInTheDocument() + expect(screen.getByText(externalDataset.name)).toBeInTheDocument() + }) + }) + + describe('Performance and Error Handling', () => { + it('should handle large dataset lists efficiently', () => { + // Create many datasets to test performance + const manyDatasets = Array.from({ length: 50 }, (_, i) => + createMockDataset({ + id: `ds-${i}`, + name: `Dataset ${i}`, + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ) + + renderDatasetConfig({ + dataSets: manyDatasets, + }) + + expect(screen.getByTestId('params-config')).toHaveTextContent('Params (50)') + }) + + it('should handle metadata intersection calculation efficiently', () => { + const datasets = [ + createMockDataset({ + id: 'ds1', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'tags', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + createMockDataset({ + id: 'ds2', + doc_metadata: [ + { name: 'category', type: 'string' } as any, + { name: 'status', type: 'string' } as any, + { name: 'priority', type: 'number' } as any, + ], + }), + ] + + renderDatasetConfig({ + dataSets: datasets, + }) + + // Should calculate intersection correctly + expect(screen.getByTestId('metadata-filter')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx new file mode 100644 index 0000000000..e44eba6c03 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.spec.tsx @@ -0,0 +1,390 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ConfigContent from './config-content' +import type { DataSet } from '@/models/datasets' +import { ChunkingMode, DataSourceType, DatasetPermission, RerankingModeEnum, WeightedScoreEnum } from '@/models/datasets' +import type { DatasetConfigs } from '@/models/debug' +import { RETRIEVE_METHOD, RETRIEVE_TYPE } from '@/types/app' +import type { RetrievalConfig } from '@/types/app' +import Toast from '@/app/components/base/toast' +import type { IndexingType } from '@/app/components/datasets/create/step-two' +import { + useCurrentProviderAndModel, + useModelListAndDefaultModelAndCurrentProviderAndModel, +} from '@/app/components/header/account-setting/model-provider-page/hooks' + +jest.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => { + type Props = { + defaultModel?: { provider: string; model: string } + onSelect?: (model: { provider: string; model: string }) => void + } + + const MockModelSelector = ({ defaultModel, onSelect }: Props) => ( + + ) + + return { + __esModule: true, + default: 
MockModelSelector,
+  }
+})
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
+  __esModule: true,
+  default: () => <div />
, +})) + +jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModelAndCurrentProviderAndModel: jest.fn(), + useCurrentProviderAndModel: jest.fn(), +})) + +const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as jest.MockedFunction +const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as jest.MockedFunction + +let toastNotifySpy: jest.SpyInstance + +const baseRetrievalConfig: RetrievalConfig = { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: 'provider', + reranking_model_name: 'rerank-model', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, +} + +const defaultIndexingTechnique: IndexingType = 'high_quality' as IndexingType + +const createDataset = (overrides: Partial = {}): DataSet => { + const { + retrieval_model, + retrieval_model_dict, + icon_info, + ...restOverrides + } = overrides + + const resolvedRetrievalModelDict = { + ...baseRetrievalConfig, + ...retrieval_model_dict, + } + const resolvedRetrievalModel = { + ...baseRetrievalConfig, + ...(retrieval_model ?? retrieval_model_dict), + } + + const defaultIconInfo = { + icon: '📘', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + } + + const resolvedIconInfo = ('icon_info' in overrides) + ? icon_info + : defaultIconInfo + + return { + id: 'dataset-id', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: resolvedIconInfo as DataSet['icon_info'], + description: 'A test dataset', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: defaultIndexingTechnique, + author_name: 'author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 0, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: resolvedRetrievalModelDict, + retrieval_model: resolvedRetrievalModel, + tags: [], + external_knowledge_info: { + external_knowledge_id: 'external-id', + external_knowledge_api_id: 'api-id', + external_knowledge_api_name: 'api-name', + external_knowledge_api_endpoint: 'https://endpoint', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: true, + }, + built_in_field_enabled: true, + doc_metadata: [], + keyword_number: 3, + pipeline_id: 'pipeline-id', + is_published: true, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...restOverrides, + } +} + +const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { + return { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, + datasets: { + datasets: [], + }, + reranking_mode: RerankingModeEnum.WeightedScore, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.5, + embedding_provider_name: 'openai', + embedding_model_name: 'text-embedding', + }, + keyword_setting: { + keyword_weight: 0.5, + }, + }, + reranking_enable: false, + ...overrides, + } +} + +describe('ConfigContent', () => { + beforeEach(() => { + jest.clearAllMocks() + toastNotifySpy = jest.spyOn(Toast, 
'notify').mockImplementation(() => ({})) + mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }) + mockedUseCurrentProviderAndModel.mockReturnValue({ + currentProvider: undefined, + currentModel: undefined, + }) + }) + + afterEach(() => { + toastNotifySpy.mockRestore() + }) + + // State management + describe('Effects', () => { + it('should normalize oneWay retrieval mode to multiWay', async () => { + // Arrange + const onChange = jest.fn() + const datasetConfigs = createDatasetConfigs({ retrieval_model: RETRIEVE_TYPE.oneWay }) + + // Act + render() + + // Assert + await waitFor(() => { + expect(onChange).toHaveBeenCalled() + }) + const [nextConfigs] = onChange.mock.calls[0] + expect(nextConfigs.retrieval_model).toBe(RETRIEVE_TYPE.multiWay) + }) + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render weighted score panel when datasets are high-quality and consistent', () => { + // Arrange + const onChange = jest.fn() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('dataset.weightedScore.title')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.semantic')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.keyword')).toBeInTheDocument() + }) + }) + + // User interactions + describe('User Interactions', () => { + it('should update weights when user changes weighted score slider', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + weights: { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: 0.5, + embedding_provider_name: 'openai', + embedding_model_name: 'text-embedding', + }, + keyword_setting: { + keyword_weight: 0.5, + }, + }, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + + const weightedScoreSlider = screen.getAllByRole('slider') + .find(slider => slider.getAttribute('aria-valuemax') === '1') + expect(weightedScoreSlider).toBeDefined() + await user.click(weightedScoreSlider!) + const callsBefore = onChange.mock.calls.length + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange.mock.calls.length).toBeGreaterThan(callsBefore) + const [nextConfigs] = onChange.mock.calls.at(-1) ?? 
[] + expect(nextConfigs?.weights?.vector_setting.vector_weight).toBeCloseTo(0.6, 5) + expect(nextConfigs?.weights?.keyword_setting.keyword_weight).toBeCloseTo(0.4, 5) + }) + + it('should warn when switching to rerank model mode without a valid model', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + const datasetConfigs = createDatasetConfigs({ + reranking_mode: RerankingModeEnum.WeightedScore, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'high_quality' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + await user.click(screen.getByText('common.modelProvider.rerankModel.key')) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'workflow.errorMsg.rerankModelRequired', + }) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_mode: RerankingModeEnum.RerankingModel, + }), + ) + }) + + it('should warn when enabling rerank without a valid model in manual toggle mode', async () => { + // Arrange + const user = userEvent.setup() + const onChange = jest.fn() + const datasetConfigs = createDatasetConfigs({ + reranking_enable: false, + }) + const selectedDatasets: DataSet[] = [ + createDataset({ + indexing_technique: 'economy' as IndexingType, + provider: 'dify', + embedding_model: 'text-embedding', + embedding_model_provider: 'openai', + retrieval_model_dict: { + ...baseRetrievalConfig, + search_method: RETRIEVE_METHOD.semantic, + }, + }), + ] + + // Act + render( + , + ) + await user.click(screen.getByRole('switch')) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'workflow.errorMsg.rerankModelRequired', + }) + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + reranking_enable: true, + }), + ) + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 8e06d6c901..c7a43fbfbd 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -20,7 +20,7 @@ import type { DataSet, } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' import Toast from '@/app/components/base/toast' diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx new file mode 100644 index 0000000000..c432ca68e2 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx @@ -0,0 +1,265 @@ +import * as React from 'react' +import { render, screen, waitFor, within } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import ParamsConfig from './index' +import ConfigContext from '@/context/debug-configuration' +import type { DatasetConfigs } from '@/models/debug' +import { RerankingModeEnum } from '@/models/datasets' 
+import { RETRIEVE_TYPE } from '@/types/app'
+import Toast from '@/app/components/base/toast'
+import {
+  useCurrentProviderAndModel,
+  useModelListAndDefaultModelAndCurrentProviderAndModel,
+} from '@/app/components/header/account-setting/model-provider-page/hooks'
+
+jest.mock('@headlessui/react', () => ({
+  Dialog: ({ children, className }: { children: React.ReactNode; className?: string }) => (
+    <div role="dialog" className={className}>
+      {children}
+    </div>
+  ),
+  DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
+    <div className={className} {...props}>
+      {children}
+    </div>
+  ),
+  DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
+    <div className={className} {...props}>
+      {children}
+    </div>
+  ),
+  Transition: ({ show, children }: { show: boolean; children: React.ReactNode }) => (show ? <>{children}</> : null),
+  TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}</>,
+  Switch: ({ checked, onChange, children, ...props }: { checked: boolean; onChange?: (value: boolean) => void; children?: React.ReactNode }) => (
+    <button type="button" role="switch" aria-checked={checked} onClick={() => onChange?.(!checked)} {...props}>
+      {children}
+    </button>
+  ),
+}))
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
+  useModelListAndDefaultModelAndCurrentProviderAndModel: jest.fn(),
+  useCurrentProviderAndModel: jest.fn(),
+}))
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => {
+  type Props = {
+    defaultModel?: { provider: string; model: string }
+    onSelect?: (model: { provider: string; model: string }) => void
+  }
+
+  const MockModelSelector = ({ defaultModel, onSelect }: Props) => (
+    <button type="button" onClick={() => defaultModel && onSelect?.(defaultModel)}>
+      {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'}
+    </button>
+  )
+
+  return {
+    __esModule: true,
+    default: MockModelSelector,
+  }
+})
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
+  __esModule: true,
+  default: () => <div />
, +})) + +const mockedUseModelListAndDefaultModelAndCurrentProviderAndModel = useModelListAndDefaultModelAndCurrentProviderAndModel as jest.MockedFunction +const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as jest.MockedFunction +let toastNotifySpy: jest.SpyInstance + +const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { + return { + retrieval_model: RETRIEVE_TYPE.multiWay, + reranking_model: { + reranking_provider_name: 'provider', + reranking_model_name: 'rerank-model', + }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0, + datasets: { + datasets: [], + }, + reranking_enable: false, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, + } +} + +const renderParamsConfig = ({ + datasetConfigs = createDatasetConfigs(), + initialModalOpen = false, + disabled, +}: { + datasetConfigs?: DatasetConfigs + initialModalOpen?: boolean + disabled?: boolean +} = {}) => { + const Wrapper = ({ children }: { children: React.ReactNode }) => { + const [datasetConfigsState, setDatasetConfigsState] = React.useState(datasetConfigs) + const [modalOpen, setModalOpen] = React.useState(initialModalOpen) + + const contextValue = { + datasetConfigs: datasetConfigsState, + setDatasetConfigs: (next: DatasetConfigs) => { + setDatasetConfigsState(next) + }, + rerankSettingModalOpen: modalOpen, + setRerankSettingModalOpen: (open: boolean) => { + setModalOpen(open) + }, + } as unknown as React.ComponentProps['value'] + + return ( + + {children} + + ) + } + + return render( + , + { wrapper: Wrapper }, + ) +} + +describe('dataset-config/params-config', () => { + beforeEach(() => { + jest.clearAllMocks() + jest.useRealTimers() + toastNotifySpy = jest.spyOn(Toast, 'notify').mockImplementation(() => ({})) + mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ + modelList: [], + defaultModel: undefined, + currentProvider: undefined, + currentModel: undefined, + }) + mockedUseCurrentProviderAndModel.mockReturnValue({ + currentProvider: undefined, + currentModel: undefined, + }) + }) + + afterEach(() => { + toastNotifySpy.mockRestore() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should disable settings trigger when disabled is true', () => { + // Arrange + renderParamsConfig({ disabled: true }) + + // Assert + expect(screen.getByRole('button', { name: 'dataset.retrievalSettings' })).toBeDisabled() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should open modal and persist changes when save is clicked', async () => { + // Arrange + renderParamsConfig() + const user = userEvent.setup() + + // Act + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + + const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + await user.click(incrementButtons[0]) + + await waitFor(() => { + const [topKInput] = dialogScope.getAllByRole('spinbutton') + expect(topKInput).toHaveValue(5) + }) + + await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) + + await waitFor(() => { + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const reopenedScope = within(reopenedDialog) + const [reopenedTopKInput] = 
reopenedScope.getAllByRole('spinbutton') + + // Assert + expect(reopenedTopKInput).toHaveValue(5) + }) + + it('should discard changes when cancel is clicked', async () => { + // Arrange + renderParamsConfig() + const user = userEvent.setup() + + // Act + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + + const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + await user.click(incrementButtons[0]) + + await waitFor(() => { + const [topKInput] = dialogScope.getAllByRole('spinbutton') + expect(topKInput).toHaveValue(5) + }) + + const cancelButton = await dialogScope.findByRole('button', { name: 'common.operation.cancel' }) + await user.click(cancelButton) + await waitFor(() => { + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + // Re-open and verify the original value remains. + await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) + const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const reopenedScope = within(reopenedDialog) + const [reopenedTopKInput] = reopenedScope.getAllByRole('spinbutton') + + // Assert + expect(reopenedTopKInput).toHaveValue(4) + }) + + it('should prevent saving when rerank model is required but invalid', async () => { + // Arrange + renderParamsConfig({ + datasetConfigs: createDatasetConfigs({ + reranking_enable: true, + reranking_mode: RerankingModeEnum.RerankingModel, + }), + initialModalOpen: true, + }) + const user = userEvent.setup() + + // Act + const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) + const dialogScope = within(dialog) + await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(toastNotifySpy).toHaveBeenCalledWith({ + type: 'error', + message: 'appDebug.datasetConfig.rerankModelRequired', + }) + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index df2b4293c4..24da958217 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -4,7 +4,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { RiEqualizer2Line } from '@remixicon/react' import ConfigContent from './config-content' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ConfigContext from '@/context/debug-configuration' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx new file mode 100644 index 0000000000..e7b1eb8421 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.spec.tsx @@ -0,0 +1,81 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import WeightedScore from './weighted-score' + +describe('WeightedScore', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => 
{ + it('should render semantic and keyword weights', () => { + // Arrange + const onChange = jest.fn() + const value = { value: [0.3, 0.7] } + + // Act + render() + + // Assert + expect(screen.getByText('dataset.weightedScore.semantic')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.keyword')).toBeInTheDocument() + expect(screen.getByText('0.3')).toBeInTheDocument() + expect(screen.getByText('0.7')).toBeInTheDocument() + }) + + it('should format a weight of 1 as 1.0', () => { + // Arrange + const onChange = jest.fn() + const value = { value: [1, 0] } + + // Act + render() + + // Assert + expect(screen.getByText('1.0')).toBeInTheDocument() + expect(screen.getByText('0')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should emit complementary weights when the slider value changes', async () => { + // Arrange + const onChange = jest.fn() + const value = { value: [0.5, 0.5] } + const user = userEvent.setup() + render() + + // Act + await user.tab() + const slider = screen.getByRole('slider') + expect(slider).toHaveFocus() + const callsBefore = onChange.mock.calls.length + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange.mock.calls.length).toBeGreaterThan(callsBefore) + const lastCall = onChange.mock.calls.at(-1)?.[0] + expect(lastCall?.value[0]).toBeCloseTo(0.6, 5) + expect(lastCall?.value[1]).toBeCloseTo(0.4, 5) + }) + + it('should not call onChange when readonly is true', async () => { + // Arrange + const onChange = jest.fn() + const value = { value: [0.5, 0.5] } + const user = userEvent.setup() + render() + + // Act + await user.tab() + const slider = screen.getByRole('slider') + expect(slider).toHaveFocus() + await user.keyboard('{ArrowRight}') + + // Assert + expect(onChange).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index ebfa3b1e12..459623104d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -2,7 +2,7 @@ import { memo } from 'react' import { useTranslation } from 'react-i18next' import './weighted-score.css' import Slider from '@/app/components/base/slider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' const formatNumber = (value: number) => { diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 6857c38e1e..f02fdcb5d7 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -10,7 +10,7 @@ import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git 
a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx new file mode 100644 index 0000000000..e2c5307b03 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -0,0 +1,537 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SettingsModal from './index' +import { ToastContext } from '@/app/components/base/toast' +import type { DataSet } from '@/models/datasets' +import { ChunkingMode, DataSourceType, DatasetPermission, RerankingModeEnum } from '@/models/datasets' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { updateDatasetSetting } from '@/service/datasets' +import { useMembers } from '@/service/use-common' +import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' + +const mockNotify = jest.fn() +const mockOnCancel = jest.fn() +const mockOnSave = jest.fn() +const mockSetShowAccountSettingModal = jest.fn() +let mockIsWorkspaceDatasetOperator = false + +const mockUseModelList = jest.fn() +const mockUseModelListAndDefaultModel = jest.fn() +const mockUseModelListAndDefaultModelAndCurrentProviderAndModel = jest.fn() +const mockUseCurrentProviderAndModel = jest.fn() +const mockCheckShowMultiModalTip = jest.fn() + +jest.mock('ky', () => { + const ky = () => ky + ky.extend = () => ky + ky.create = () => ky + return { __esModule: true, default: ky } +}) + +jest.mock('@/app/components/datasets/create/step-two', () => ({ + __esModule: true, + IndexingType: { + QUALIFIED: 'high_quality', + ECONOMICAL: 'economy', + }, +})) + +jest.mock('@/service/datasets', () => ({ + updateDatasetSetting: jest.fn(), +})) + +jest.mock('@/service/use-common', () => ({ + __esModule: true, + ...jest.requireActual('@/service/use-common'), + useMembers: jest.fn(), +})) + +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ isCurrentWorkspaceDatasetOperator: mockIsWorkspaceDatasetOperator }), + useSelector: (selector: (value: { userProfile: { id: string; name: string; email: string; avatar_url: string } }) => T) => selector({ + userProfile: { + id: 'user-1', + name: 'User One', + email: 'user@example.com', + avatar_url: 'avatar.png', + }, + }), +})) + +jest.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }), +})) + +jest.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs${path}`, +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + modelProviders: [], + textGenerationModelList: [], + supportRetrievalMethods: [ + RETRIEVE_METHOD.semantic, + RETRIEVE_METHOD.fullText, + RETRIEVE_METHOD.hybrid, + RETRIEVE_METHOD.keywordSearch, + ], + }), +})) + +jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + __esModule: true, + useModelList: (...args: unknown[]) => mockUseModelList(...args), + useModelListAndDefaultModel: (...args: unknown[]) => mockUseModelListAndDefaultModel(...args), + useModelListAndDefaultModelAndCurrentProviderAndModel: (...args: unknown[]) => + mockUseModelListAndDefaultModelAndCurrentProviderAndModel(...args), + 
useCurrentProviderAndModel: (...args: unknown[]) => mockUseCurrentProviderAndModel(...args),
+}))
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({
+  __esModule: true,
+  default: ({ defaultModel }: { defaultModel?: { provider: string; model: string } }) => (
+    <div>
+      {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'}
+    </div>
+ ), +})) + +jest.mock('@/app/components/datasets/settings/utils', () => ({ + checkShowMultiModalTip: (...args: unknown[]) => mockCheckShowMultiModalTip(...args), +})) + +const mockUpdateDatasetSetting = updateDatasetSetting as jest.MockedFunction +const mockUseMembers = useMembers as jest.MockedFunction + +const createRetrievalConfig = (overrides: Partial = {}): RetrievalConfig => ({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 2, + score_threshold_enabled: false, + score_threshold: 0.5, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, +}) + +const createDataset = (overrides: Partial = {}, retrievalOverrides: Partial = {}): DataSet => { + const retrievalConfig = createRetrievalConfig(retrievalOverrides) + return { + id: 'dataset-id', + name: 'Test Dataset', + indexing_status: 'completed', + icon_info: { + icon: 'icon', + icon_type: 'emoji', + }, + description: 'Description', + permission: DatasetPermission.allTeamMembers, + data_source_type: DataSourceType.FILE, + indexing_technique: IndexingType.QUALIFIED, + author_name: 'Author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 1700000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'internal', + embedding_model: 'embed-model', + embedding_model_provider: 'embed-provider', + embedding_available: true, + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-id', + external_knowledge_api_id: 'ext-api-id', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + doc_metadata: [], + keyword_number: 10, + pipeline_id: 'pipeline-id', + is_published: false, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + retrieval_model_dict: { + ...retrievalConfig, + ...overrides.retrieval_model_dict, + }, + retrieval_model: { + ...retrievalConfig, + ...overrides.retrieval_model, + }, + } +} + +const renderWithProviders = (dataset: DataSet) => { + return render( + + + , + ) +} + +const createMemberList = (): DataSet['partial_member_list'] => ([ + 'member-2', +]) + +const renderSettingsModal = async (dataset: DataSet) => { + renderWithProviders(dataset) + await waitFor(() => expect(mockUseMembers).toHaveBeenCalled()) +} + +describe('SettingsModal', () => { + beforeEach(() => { + jest.clearAllMocks() + mockIsWorkspaceDatasetOperator = false + mockUseMembers.mockReturnValue({ + data: { + accounts: [ + { + id: 'user-1', + name: 'User One', + email: 'user@example.com', + avatar: 'avatar.png', + avatar_url: 'avatar.png', + status: 'active', + role: 'owner', + }, + { + id: 'member-2', + name: 'Member Two', + email: 'member@example.com', + avatar: 'avatar.png', + avatar_url: 'avatar.png', + status: 'active', + role: 'editor', + }, + ], + }, + } as ReturnType) + mockUseModelList.mockImplementation((type: ModelTypeEnum) => { + if (type === ModelTypeEnum.rerank) { + return { + data: [ + { + provider: 'rerank-provider', + models: [{ model: 'rerank-model' }], + }, + ], + } + } + return { data: [{ provider: 'embed-provider', models: [{ model: 'embed-model' }] }] } + }) + mockUseModelListAndDefaultModel.mockReturnValue({ modelList: [], 
defaultModel: null }) + mockUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ defaultModel: null, currentModel: null }) + mockUseCurrentProviderAndModel.mockReturnValue({ currentProvider: null, currentModel: null }) + mockCheckShowMultiModalTip.mockReturnValue(false) + mockUpdateDatasetSetting.mockResolvedValue(createDataset()) + }) + + // Rendering and basic field bindings. + describe('Rendering', () => { + it('should render dataset details when dataset is provided', async () => { + // Arrange + const dataset = createDataset() + + // Act + await renderSettingsModal(dataset) + + // Assert + expect(screen.getByPlaceholderText('datasetSettings.form.namePlaceholder')).toHaveValue('Test Dataset') + expect(screen.getByPlaceholderText('datasetSettings.form.descPlaceholder')).toHaveValue('Description') + }) + + it('should show external knowledge info when dataset is external', async () => { + // Arrange + const dataset = createDataset({ + provider: 'external', + external_knowledge_info: { + external_knowledge_id: 'ext-id-123', + external_knowledge_api_id: 'ext-api-id-123', + external_knowledge_api_name: 'External Knowledge API', + external_knowledge_api_endpoint: 'https://api.external.com', + }, + }) + + // Act + await renderSettingsModal(dataset) + + // Assert + expect(screen.getByText('External Knowledge API')).toBeInTheDocument() + expect(screen.getByText('https://api.external.com')).toBeInTheDocument() + expect(screen.getByText('ext-id-123')).toBeInTheDocument() + }) + }) + + // User interactions that update visible state. + describe('Interactions', () => { + it('should call onCancel when cancel button is clicked', async () => { + // Arrange + const user = userEvent.setup() + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(mockOnCancel).toHaveBeenCalledTimes(1) + }) + + it('should update name input when user types', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + + // Act + await user.clear(nameInput) + await user.type(nameInput, 'New Dataset Name') + + // Assert + expect(nameInput).toHaveValue('New Dataset Name') + }) + + it('should update description input when user types', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const descriptionInput = screen.getByPlaceholderText('datasetSettings.form.descPlaceholder') + + // Act + await user.clear(descriptionInput) + await user.type(descriptionInput, 'New description') + + // Assert + expect(descriptionInput).toHaveValue('New description') + }) + + it('should show and dismiss retrieval change tip when indexing method changes', async () => { + // Arrange + const user = userEvent.setup() + const dataset = createDataset({ indexing_technique: IndexingType.ECONOMICAL }) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByText('datasetCreation.stepTwo.qualified')) + + // Assert + expect(await screen.findByText('appDebug.datasetConfig.retrieveChangeTip')).toBeInTheDocument() + + // Act + await user.click(screen.getByLabelText('close-retrieval-change-tip')) + + // Assert + await waitFor(() => { + expect(screen.queryByText('appDebug.datasetConfig.retrieveChangeTip')).not.toBeInTheDocument() + }) + }) + + it('should open account setting modal when embedding model tip is clicked', async () => 
{ + // Arrange + const user = userEvent.setup() + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByText('datasetSettings.form.embeddingModelTipLink')) + + // Assert + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.PROVIDER }) + }) + }) + + // Validation guardrails before saving. + describe('Validation', () => { + it('should block save when dataset name is empty', async () => { + // Arrange + const user = userEvent.setup() + await renderSettingsModal(createDataset()) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + + // Act + await user.clear(nameInput) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'datasetSettings.form.nameError', + })) + expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() + }) + + it('should block save when reranking is enabled without model', async () => { + // Arrange + const user = userEvent.setup() + mockUseModelList.mockReturnValue({ data: [] }) + const dataset = createDataset({}, createRetrievalConfig({ + reranking_enable: true, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + })) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appDebug.datasetConfig.rerankModelRequired', + })) + expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() + }) + }) + + // Save flows and side effects. + describe('Save', () => { + it('should save internal dataset changes when form is valid', async () => { + // Arrange + const user = userEvent.setup() + const rerankRetrieval = createRetrievalConfig({ + reranking_enable: true, + reranking_model: { + reranking_provider_name: 'rerank-provider', + reranking_model_name: 'rerank-model', + }, + }) + const dataset = createDataset({ + retrieval_model: rerankRetrieval, + retrieval_model_dict: rerankRetrieval, + }) + + // Act + await renderSettingsModal(dataset) + + const nameInput = screen.getByPlaceholderText('datasetSettings.form.namePlaceholder') + await user.clear(nameInput) + await user.type(nameInput, 'Updated Internal Dataset') + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => expect(mockUpdateDatasetSetting).toHaveBeenCalled()) + + expect(mockUpdateDatasetSetting).toHaveBeenCalledWith(expect.objectContaining({ + body: expect.objectContaining({ + name: 'Updated Internal Dataset', + permission: DatasetPermission.allTeamMembers, + }), + })) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + message: 'common.actionMsg.modifiedSuccessfully', + })) + expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({ + name: 'Updated Internal Dataset', + retrieval_model_dict: expect.objectContaining({ + reranking_enable: true, + }), + })) + }) + + it('should save external dataset changes when partial members configured', async () => { + // Arrange + const user = userEvent.setup() + const dataset = createDataset({ + provider: 'external', + permission: DatasetPermission.partialMembers, + partial_member_list: createMemberList(), + external_retrieval_model: { + top_k: 5, + score_threshold: 0.3, + score_threshold_enabled: true, + }, + }, { + 
score_threshold_enabled: true, + score_threshold: 0.8, + }) + + // Act + await renderSettingsModal(dataset) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => expect(mockUpdateDatasetSetting).toHaveBeenCalled()) + + expect(mockUpdateDatasetSetting).toHaveBeenCalledWith(expect.objectContaining({ + body: expect.objectContaining({ + permission: DatasetPermission.partialMembers, + external_retrieval_model: expect.objectContaining({ + top_k: 5, + }), + partial_member_list: [ + { + user_id: 'member-2', + role: 'editor', + }, + ], + }), + })) + expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({ + retrieval_model_dict: expect.objectContaining({ + score_threshold_enabled: true, + score_threshold: 0.8, + }), + })) + }) + + it('should disable save button while saving', async () => { + // Arrange + const user = userEvent.setup() + mockUpdateDatasetSetting.mockImplementation(() => new Promise(resolve => setTimeout(resolve, 100))) + + // Act + await renderSettingsModal(createDataset()) + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(saveButton).toBeDisabled() + }) + + it('should show error toast when save fails', async () => { + // Arrange + const user = userEvent.setup() + mockUpdateDatasetSetting.mockRejectedValue(new Error('API Error')) + + // Act + await renderSettingsModal(createDataset()) + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) + }) + }) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index cd6e39011e..c191ff5d46 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -1,13 +1,10 @@ import type { FC } from 'react' -import { useMemo, useRef, useState } from 'react' -import { useMount } from 'ahooks' +import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { isEqual } from 'lodash-es' import { RiCloseLine } from '@remixicon/react' -import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import IndexMethod from '@/app/components/datasets/settings/index-method' -import Divider from '@/app/components/base/divider' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' @@ -18,20 +15,17 @@ import { useAppContext } from '@/context/app-context' import { useModalContext } from '@/context/modal-context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import type { RetrievalConfig } from '@/types/app' -import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' -import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' -import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' -import { AlertTriangle } 
from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { fetchMembers } from '@/service/common' import type { Member } from '@/models/common' import { IndexingType } from '@/app/components/datasets/create/step-two' import { useDocLink } from '@/context/i18n' +import { useMembers } from '@/service/use-common' import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils' +import { RetrievalChangeTip, RetrievalSection } from './retrieval-section' type SettingsModalProps = { currentDataset: DataSet @@ -68,6 +62,7 @@ const SettingsModal: FC = ({ const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(localeCurrentDataset?.external_retrieval_model.score_threshold_enabled ?? false) const [selectedMemberIDs, setSelectedMemberIDs] = useState(currentDataset.partial_member_list || []) const [memberList, setMemberList] = useState([]) + const { data: membersData } = useMembers() const [indexMethod, setIndexMethod] = useState(currentDataset.indexing_technique) const [retrievalConfig, setRetrievalConfig] = useState(localeCurrentDataset?.retrieval_model_dict as RetrievalConfig) @@ -165,17 +160,12 @@ const SettingsModal: FC = ({ } } - const getMembers = async () => { - const { accounts } = await fetchMembers({ url: '/workspaces/current/members', params: {} }) - if (!accounts) + useEffect(() => { + if (!membersData?.accounts) setMemberList([]) else - setMemberList(accounts) - } - - useMount(() => { - getMembers() - }) + setMemberList(membersData.accounts) + }, [membersData]) const showMultiModalTip = useMemo(() => { return checkShowMultiModalTip({ @@ -298,92 +288,37 @@ const SettingsModal: FC = ({ )} {/* Retrieval Method Config */} - {currentDataset?.provider === 'external' - ? <> -
-
-
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- -
-
-
-
-
{t('datasetSettings.form.externalKnowledgeAPI')}
-
-
-
- -
- {currentDataset?.external_knowledge_info.external_knowledge_api_name} -
-
·
-
{currentDataset?.external_knowledge_info.external_knowledge_api_endpoint}
-
-
-
-
-
-
{t('datasetSettings.form.externalKnowledgeID')}
-
-
-
-
{currentDataset?.external_knowledge_info.external_knowledge_id}
-
-
-
-
- - :
-
-
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- {t('datasetSettings.form.retrievalSetting.learnMore')} - {t('datasetSettings.form.retrievalSetting.description')} -
-
-
-
- {indexMethod === IndexingType.QUALIFIED - ? ( - - ) - : ( - - )} -
-
} + {isExternal ? ( + + ) : ( + + )}
- {isRetrievalChanged && !isHideChangedTip && ( -
-
- -
{t('appDebug.datasetConfig.retrieveChangeTip')}
-
-
{ - setIsHideChangedTip(true) - e.stopPropagation() - e.nativeEvent.stopImmediatePropagation() - }}> - -
-
- )} + setIsHideChangedTip(true)} + />
{
+  const ky = () => ky
+  ky.extend = () => ky
+  ky.create = () => ky
+  return { __esModule: true, default: ky }
+})
+
+jest.mock('@/context/provider-context', () => ({
+  useProviderContext: () => ({
+    modelProviders: [],
+    textGenerationModelList: [],
+    supportRetrievalMethods: [
+      RETRIEVE_METHOD.semantic,
+      RETRIEVE_METHOD.fullText,
+      RETRIEVE_METHOD.hybrid,
+      RETRIEVE_METHOD.keywordSearch,
+    ],
+  }),
+}))
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
+  __esModule: true,
+  useModelListAndDefaultModelAndCurrentProviderAndModel: (...args: unknown[]) =>
+    mockUseModelListAndDefaultModelAndCurrentProviderAndModel(...args),
+  useModelListAndDefaultModel: (...args: unknown[]) => mockUseModelListAndDefaultModel(...args),
+  useModelList: (...args: unknown[]) => mockUseModelList(...args),
+  useCurrentProviderAndModel: (...args: unknown[]) => mockUseCurrentProviderAndModel(...args),
+}))
+
+jest.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({
+  __esModule: true,
+  default: ({ defaultModel }: { defaultModel?: { provider: string; model: string } }) => (
+
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'} +
+ ), +})) + +jest.mock('@/app/components/datasets/create/step-two', () => ({ + __esModule: true, + IndexingType: { + QUALIFIED: 'high_quality', + ECONOMICAL: 'economy', + }, +})) + +const createRetrievalConfig = (overrides: Partial = {}): RetrievalConfig => ({ + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 2, + score_threshold_enabled: false, + score_threshold: 0.5, + reranking_mode: RerankingModeEnum.RerankingModel, + ...overrides, +}) + +const createDataset = (overrides: Partial = {}, retrievalOverrides: Partial = {}): DataSet => { + const retrievalConfig = createRetrievalConfig(retrievalOverrides) + return { + id: 'dataset-id', + name: 'Test Dataset', + indexing_status: 'completed', + icon_info: { + icon: 'icon', + icon_type: 'emoji', + }, + description: 'Description', + permission: DatasetPermission.allTeamMembers, + data_source_type: DataSourceType.FILE, + indexing_technique: IndexingType.QUALIFIED, + author_name: 'Author', + created_by: 'creator', + updated_by: 'updater', + updated_at: 1700000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + total_available_documents: 0, + word_count: 0, + provider: 'internal', + embedding_model: 'embed-model', + embedding_model_provider: 'embed-provider', + embedding_available: true, + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-id', + external_knowledge_api_id: 'ext-api-id', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 2, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + doc_metadata: [], + keyword_number: 10, + pipeline_id: 'pipeline-id', + is_published: false, + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + ...overrides, + retrieval_model_dict: { + ...retrievalConfig, + ...overrides.retrieval_model_dict, + }, + retrieval_model: { + ...retrievalConfig, + ...overrides.retrieval_model, + }, + } +} + +describe('RetrievalChangeTip', () => { + const defaultProps = { + visible: true, + message: 'Test message', + onDismiss: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('renders and supports dismiss', async () => { + // Arrange + const onDismiss = jest.fn() + render() + + // Act + await userEvent.click(screen.getByRole('button', { name: 'close-retrieval-change-tip' })) + + // Assert + expect(screen.getByText('Test message')).toBeInTheDocument() + expect(onDismiss).toHaveBeenCalledTimes(1) + }) + + it('does not render when hidden', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByText('Test message')).not.toBeInTheDocument() + }) +}) + +describe('RetrievalSection', () => { + const t = (key: string) => key + const rowClass = 'row' + const labelClass = 'label' + + beforeEach(() => { + jest.clearAllMocks() + mockUseModelList.mockImplementation((type: ModelTypeEnum) => { + if (type === ModelTypeEnum.rerank) + return { data: [{ provider: 'rerank-provider', models: [{ model: 'rerank-model' }] }] } + return { data: [] } + }) + mockUseModelListAndDefaultModel.mockReturnValue({ modelList: [], defaultModel: null }) + mockUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ defaultModel: null, currentModel: null }) + mockUseCurrentProviderAndModel.mockReturnValue({ currentProvider: null, currentModel: 
null }) + }) + + it('renders external retrieval details and propagates changes', async () => { + // Arrange + const dataset = createDataset({ + provider: 'external', + external_knowledge_info: { + external_knowledge_id: 'ext-id-999', + external_knowledge_api_id: 'ext-api-id-999', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.external.com', + }, + }) + const handleExternalChange = jest.fn() + + // Act + render( + , + ) + const [topKIncrement] = screen.getAllByLabelText('increment') + await userEvent.click(topKIncrement) + + // Assert + expect(screen.getByText('External API')).toBeInTheDocument() + expect(screen.getByText('https://api.external.com')).toBeInTheDocument() + expect(screen.getByText('ext-id-999')).toBeInTheDocument() + expect(handleExternalChange).toHaveBeenCalledWith(expect.objectContaining({ top_k: 4 })) + }) + + it('renders internal retrieval config with doc link', () => { + // Arrange + const docLink = jest.fn((path: string) => `https://docs.example${path}`) + const retrievalConfig = createRetrievalConfig() + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('dataset.retrieval.semantic_search.title')).toBeInTheDocument() + const learnMoreLink = screen.getByRole('link', { name: 'datasetSettings.form.retrievalSetting.learnMore' }) + expect(learnMoreLink).toHaveAttribute('href', 'https://docs.example/guides/knowledge-base/create-knowledge-and-upload-documents/setting-indexing-methods#setting-the-retrieval-setting') + expect(docLink).toHaveBeenCalledWith('/guides/knowledge-base/create-knowledge-and-upload-documents/setting-indexing-methods#setting-the-retrieval-setting') + }) + + it('propagates retrieval config changes for economical indexing', async () => { + // Arrange + const handleRetrievalChange = jest.fn() + + // Act + render( + path} + />, + ) + const [topKIncrement] = screen.getAllByLabelText('increment') + await userEvent.click(topKIncrement) + + // Assert + expect(screen.getByText('dataset.retrieval.keyword_search.title')).toBeInTheDocument() + expect(handleRetrievalChange).toHaveBeenCalledWith(expect.objectContaining({ + top_k: 3, + })) + }) +}) diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx new file mode 100644 index 0000000000..99d042f681 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/settings-modal/retrieval-section.tsx @@ -0,0 +1,218 @@ +import { RiCloseLine } from '@remixicon/react' +import type { FC } from 'react' +import { cn } from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' +import type { DataSet } from '@/models/datasets' +import { IndexingType } from '@/app/components/datasets/create/step-two' +import type { RetrievalConfig } from '@/types/app' +import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' +import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' + +type CommonSectionProps = { + rowClass: string + labelClass: string + t: (key: string, options?: any) => string +} + +type ExternalRetrievalSectionProps = 
CommonSectionProps & {
+  topK: number
+  scoreThreshold: number
+  scoreThresholdEnabled: boolean
+  onExternalSettingChange: (data: { top_k?: number; score_threshold?: number; score_threshold_enabled?: boolean }) => void
+  currentDataset: DataSet
+}
+
+const ExternalRetrievalSection: FC = ({
+  rowClass,
+  labelClass,
+  t,
+  topK,
+  scoreThreshold,
+  scoreThresholdEnabled,
+  onExternalSettingChange,
+  currentDataset,
+}) => (
+  <>
+
+
+
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ +
+
+
+
+
{t('datasetSettings.form.externalKnowledgeAPI')}
+
+
+
+ +
+ {currentDataset?.external_knowledge_info.external_knowledge_api_name} +
+
·
+
{currentDataset?.external_knowledge_info.external_knowledge_api_endpoint}
+
+
+
+
+
+
{t('datasetSettings.form.externalKnowledgeID')}
+
+
+
+
{currentDataset?.external_knowledge_info.external_knowledge_id}
+
+
+
+
+
+)
+
+type InternalRetrievalSectionProps = CommonSectionProps & {
+  indexMethod: IndexingType
+  retrievalConfig: RetrievalConfig
+  showMultiModalTip: boolean
+  onRetrievalConfigChange: (value: RetrievalConfig) => void
+  docLink: (path: string) => string
+}
+
+const InternalRetrievalSection: FC = ({
+  rowClass,
+  labelClass,
+  t,
+  indexMethod,
+  retrievalConfig,
+  showMultiModalTip,
+  onRetrievalConfigChange,
+  docLink,
+}) => (
+
+
+
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ {t('datasetSettings.form.retrievalSetting.learnMore')} + {t('datasetSettings.form.retrievalSetting.description')} +
+
+
+
+ {indexMethod === IndexingType.QUALIFIED + ? ( + + ) + : ( + + )} +
+
+)
+
+type RetrievalSectionProps
+  = | (ExternalRetrievalSectionProps & { isExternal: true })
+    | (InternalRetrievalSectionProps & { isExternal: false })
+
+export const RetrievalSection: FC = (props) => {
+  if (props.isExternal) {
+    const {
+      rowClass,
+      labelClass,
+      t,
+      topK,
+      scoreThreshold,
+      scoreThresholdEnabled,
+      onExternalSettingChange,
+      currentDataset,
+    } = props
+
+    return (
+
+    )
+  }
+
+  const {
+    rowClass,
+    labelClass,
+    t,
+    indexMethod,
+    retrievalConfig,
+    showMultiModalTip,
+    onRetrievalConfigChange,
+    docLink,
+  } = props
+
+  return (
+
+  )
+}
+
+type RetrievalChangeTipProps = {
+  visible: boolean
+  message: string
+  onDismiss: () => void
+}
+
+export const RetrievalChangeTip: FC = ({
+  visible,
+  message,
+  onDismiss,
+}) => {
+  if (!visible)
+    return null
+
+  return (
+
+
+ +
{message}
+
+ +
+ ) +} diff --git a/web/app/components/app/configuration/debug/chat-user-input.tsx b/web/app/components/app/configuration/debug/chat-user-input.tsx index 16666d514e..c25bed548c 100644 --- a/web/app/components/app/configuration/debug/chat-user-input.tsx +++ b/web/app/components/app/configuration/debug/chat-user-input.tsx @@ -7,7 +7,7 @@ import Select from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import type { Inputs } from '@/models/debug' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' type Props = { diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx new file mode 100644 index 0000000000..676456c3ea --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -0,0 +1,1011 @@ +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { type ReactNode, type RefObject, createRef } from 'react' +import DebugWithSingleModel from './index' +import type { DebugWithSingleModelRefType } from './index' +import type { ChatItem } from '@/app/components/base/chat/types' +import { ConfigurationMethodEnum, ModelFeatureEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ProviderContextState } from '@/context/provider-context' +import type { DatasetConfigs, ModelConfig } from '@/models/debug' +import { PromptMode } from '@/models/debug' +import { type Collection, CollectionType } from '@/app/components/tools/types' +import type { FileEntity } from '@/app/components/base/file-uploader/types' +import { AgentStrategy, AppModeEnum, ModelModeType, Resolution, TransferMethod } from '@/types/app' + +// ============================================================================ +// Test Data Factories (Following testing.md guidelines) +// ============================================================================ + +/** + * Factory function for creating mock ModelConfig with type safety + */ +function createMockModelConfig(overrides: Partial = {}): ModelConfig { + return { + provider: 'openai', + model_id: 'gpt-3.5-turbo', + mode: ModelModeType.chat, + configs: { + prompt_template: 'Test template', + prompt_variables: [ + { key: 'var1', name: 'Variable 1', type: 'text', required: false }, + ], + }, + chat_prompt_config: { + prompt: [], + }, + completion_prompt_config: { + prompt: { text: '' }, + conversation_histories_role: { + user_prefix: 'user', + assistant_prefix: 'assistant', + }, + }, + more_like_this: null, + opening_statement: '', + suggested_questions: [], + sensitive_word_avoidance: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + suggested_questions_after_answer: null, + retriever_resource: null, + annotation_reply: null, + external_data_tools: [], + system_parameters: { + audio_file_size_limit: 0, + file_size_limit: 0, + image_file_size_limit: 0, + video_file_size_limit: 0, + workflow_file_upload_limit: 0, + }, + dataSets: [], + agentConfig: { + enabled: false, + max_iteration: 5, + tools: [], + strategy: AgentStrategy.react, + }, + ...overrides, + } +} + +/** + * Factory function for creating mock Collection list + */ +function createMockCollections(collections: Partial[] = 
[]): Collection[] { + return collections.map((collection, index) => ({ + id: `collection-${index}`, + name: `Collection ${index}`, + icon: 'icon-url', + type: 'tool', + ...collection, + } as Collection)) +} + +/** + * Factory function for creating mock Provider Context + */ +function createMockProviderContext(overrides: Partial = {}): ProviderContextState { + return { + textGenerationModelList: [ + { + provider: 'openai', + label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' }, + icon_small: { en_US: 'icon', zh_Hans: 'icon' }, + icon_large: { en_US: 'icon', zh_Hans: 'icon' }, + status: ModelStatusEnum.active, + models: [ + { + model: 'gpt-3.5-turbo', + label: { en_US: 'GPT-3.5', zh_Hans: 'GPT-3.5' }, + model_type: ModelTypeEnum.textGeneration, + features: [ModelFeatureEnum.vision], + fetch_from: ConfigurationMethodEnum.predefinedModel, + model_properties: {}, + deprecated: false, + }, + ], + }, + ], + hasSettedApiKey: true, + modelProviders: [], + speech2textDefaultModel: null, + ttsDefaultModel: null, + agentThoughtDefaultModel: null, + updateModelList: jest.fn(), + onPlanInfoChanged: jest.fn(), + refreshModelProviders: jest.fn(), + refreshLicenseLimit: jest.fn(), + ...overrides, + } as ProviderContextState +} + +// ============================================================================ +// Mock External Dependencies ONLY (Following testing.md guidelines) +// ============================================================================ + +// Mock service layer (API calls) +jest.mock('@/service/base', () => ({ + ssePost: jest.fn(() => Promise.resolve()), + post: jest.fn(() => Promise.resolve({ data: {} })), + get: jest.fn(() => Promise.resolve({ data: {} })), + del: jest.fn(() => Promise.resolve({ data: {} })), + patch: jest.fn(() => Promise.resolve({ data: {} })), + put: jest.fn(() => Promise.resolve({ data: {} })), +})) + +jest.mock('@/service/fetch', () => ({ + fetch: jest.fn(() => Promise.resolve({ ok: true, json: () => Promise.resolve({}) })), +})) + +const mockFetchConversationMessages = jest.fn() +const mockFetchSuggestedQuestions = jest.fn() +const mockStopChatMessageResponding = jest.fn() + +jest.mock('@/service/debug', () => ({ + fetchConversationMessages: (...args: unknown[]) => mockFetchConversationMessages(...args), + fetchSuggestedQuestions: (...args: unknown[]) => mockFetchSuggestedQuestions(...args), + stopChatMessageResponding: (...args: unknown[]) => mockStopChatMessageResponding(...args), +})) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ push: jest.fn() }), + usePathname: () => '/test', + useParams: () => ({}), +})) + +// Mock complex context providers +const mockDebugConfigContext = { + appId: 'test-app-id', + isAPIKeySet: true, + isTrailFinished: false, + mode: AppModeEnum.CHAT, + modelModeType: ModelModeType.chat, + promptMode: PromptMode.simple, + setPromptMode: jest.fn(), + isAdvancedMode: false, + isAgent: false, + isFunctionCall: false, + isOpenAI: true, + collectionList: createMockCollections([ + { id: 'test-provider', name: 'Test Tool', icon: 'icon-url' }, + ]), + canReturnToSimpleMode: false, + setCanReturnToSimpleMode: jest.fn(), + chatPromptConfig: {}, + completionPromptConfig: {}, + currentAdvancedPrompt: [], + showHistoryModal: jest.fn(), + conversationHistoriesRole: { user_prefix: 'user', assistant_prefix: 'assistant' }, + setConversationHistoriesRole: jest.fn(), + setCurrentAdvancedPrompt: jest.fn(), + hasSetBlockStatus: { context: false, history: false, query: false }, + conversationId: null, + setConversationId: jest.fn(), + introduction: 
'', + setIntroduction: jest.fn(), + suggestedQuestions: [], + setSuggestedQuestions: jest.fn(), + controlClearChatMessage: 0, + setControlClearChatMessage: jest.fn(), + prevPromptConfig: { prompt_template: '', prompt_variables: [] }, + setPrevPromptConfig: jest.fn(), + moreLikeThisConfig: { enabled: false }, + setMoreLikeThisConfig: jest.fn(), + suggestedQuestionsAfterAnswerConfig: { enabled: false }, + setSuggestedQuestionsAfterAnswerConfig: jest.fn(), + speechToTextConfig: { enabled: false }, + setSpeechToTextConfig: jest.fn(), + textToSpeechConfig: { enabled: false, voice: '', language: '' }, + setTextToSpeechConfig: jest.fn(), + citationConfig: { enabled: false }, + setCitationConfig: jest.fn(), + moderationConfig: { enabled: false }, + annotationConfig: { id: '', enabled: false, score_threshold: 0.7, embedding_model: { embedding_model_name: '', embedding_provider_name: '' } }, + setAnnotationConfig: jest.fn(), + setModerationConfig: jest.fn(), + externalDataToolsConfig: [], + setExternalDataToolsConfig: jest.fn(), + formattingChanged: false, + setFormattingChanged: jest.fn(), + inputs: { var1: 'test input' }, + setInputs: jest.fn(), + query: '', + setQuery: jest.fn(), + completionParams: { max_tokens: 100, temperature: 0.7 }, + setCompletionParams: jest.fn(), + modelConfig: createMockModelConfig({ + agentConfig: { + enabled: false, + max_iteration: 5, + tools: [{ + tool_name: 'test-tool', + provider_id: 'test-provider', + provider_type: CollectionType.builtIn, + provider_name: 'test-provider', + tool_label: 'Test Tool', + tool_parameters: {}, + enabled: true, + }], + strategy: AgentStrategy.react, + }, + }), + setModelConfig: jest.fn(), + dataSets: [], + showSelectDataSet: jest.fn(), + setDataSets: jest.fn(), + datasetConfigs: { + retrieval_model: 'single', + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 4, + score_threshold_enabled: false, + score_threshold: 0.7, + datasets: { datasets: [] }, + } as DatasetConfigs, + datasetConfigsRef: createRef(), + setDatasetConfigs: jest.fn(), + hasSetContextVar: false, + isShowVisionConfig: false, + visionConfig: { enabled: false, number_limits: 2, detail: Resolution.low, transfer_methods: [] }, + setVisionConfig: jest.fn(), + isAllowVideoUpload: false, + isShowDocumentConfig: false, + isShowAudioConfig: false, + rerankSettingModalOpen: false, + setRerankSettingModalOpen: jest.fn(), +} + +jest.mock('@/context/debug-configuration', () => ({ + useDebugConfigurationContext: jest.fn(() => mockDebugConfigContext), +})) + +const mockProviderContext = createMockProviderContext() + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(() => mockProviderContext), +})) + +const mockAppContext = { + userProfile: { + id: 'user-1', + avatar_url: 'https://example.com/avatar.png', + name: 'Test User', + email: 'test@example.com', + }, + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: false, + isCurrentWorkspaceDatasetOperator: false, + mutateUserProfile: jest.fn(), +} + +jest.mock('@/context/app-context', () => ({ + useAppContext: jest.fn(() => mockAppContext), +})) + +type FeatureState = { + moreLikeThis: { enabled: boolean } + opening: { enabled: boolean; opening_statement: string; suggested_questions: string[] } + moderation: { enabled: boolean } + speech2text: { enabled: boolean } + text2speech: { enabled: boolean } + file: { enabled: boolean } + suggested: { enabled: boolean } + citation: { enabled: boolean } + annotationReply: { enabled: boolean } +} + +const defaultFeatures: 
FeatureState = { + moreLikeThis: { enabled: false }, + opening: { enabled: false, opening_statement: '', suggested_questions: [] }, + moderation: { enabled: false }, + speech2text: { enabled: false }, + text2speech: { enabled: false }, + file: { enabled: false }, + suggested: { enabled: false }, + citation: { enabled: false }, + annotationReply: { enabled: false }, +} +type FeatureSelector = (state: { features: FeatureState }) => unknown + +let mockFeaturesState: FeatureState = { ...defaultFeatures } +jest.mock('@/app/components/base/features/hooks', () => ({ + useFeatures: jest.fn(), +})) + +const mockConfigFromDebugContext = { + pre_prompt: 'Test prompt', + prompt_type: 'simple', + user_input_form: [], + dataset_query_variable: '', + opening_statement: '', + more_like_this: { enabled: false }, + suggested_questions: [], + suggested_questions_after_answer: { enabled: false }, + text_to_speech: { enabled: false }, + speech_to_text: { enabled: false }, + retriever_resource: { enabled: false }, + sensitive_word_avoidance: { enabled: false }, + agent_mode: {}, + dataset_configs: {}, + file_upload: { enabled: false }, + annotation_reply: { enabled: false }, + supportAnnotation: true, + appId: 'test-app-id', + supportCitationHitInfo: true, +} + +jest.mock('../hooks', () => ({ + useConfigFromDebugContext: jest.fn(() => mockConfigFromDebugContext), + useFormattingChangedSubscription: jest.fn(), +})) + +const mockSetShowAppConfigureFeaturesModal = jest.fn() + +jest.mock('@/app/components/app/store', () => ({ + useStore: jest.fn((selector?: (state: { setShowAppConfigureFeaturesModal: typeof mockSetShowAppConfigureFeaturesModal }) => unknown) => { + if (typeof selector === 'function') + return selector({ setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal }) + return mockSetShowAppConfigureFeaturesModal + }), +})) + +// Mock event emitter context +jest.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: jest.fn(() => ({ + eventEmitter: null, + })), +})) + +// Mock toast context +jest.mock('@/app/components/base/toast', () => ({ + useToastContext: jest.fn(() => ({ + notify: jest.fn(), + })), +})) + +// Mock hooks/use-timestamp +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: jest.fn(() => ({ + formatTime: jest.fn((timestamp: number) => new Date(timestamp).toLocaleString()), + })), +})) + +// Mock audio player manager +jest.mock('@/app/components/base/audio-btn/audio.player.manager', () => ({ + AudioPlayerManager: { + getInstance: jest.fn(() => ({ + getAudioPlayer: jest.fn(), + resetAudioPlayer: jest.fn(), + })), + }, +})) + +type MockChatProps = { + chatList?: ChatItem[] + isResponding?: boolean + onSend?: (message: string, files?: FileEntity[]) => void + onRegenerate?: (chatItem: ChatItem, editedQuestion?: { message: string; files?: FileEntity[] }) => void + onStopResponding?: () => void + suggestedQuestions?: string[] + questionIcon?: ReactNode + answerIcon?: ReactNode + onAnnotationAdded?: (annotationId: string, authorName: string, question: string, answer: string, index: number) => void + onAnnotationEdited?: (question: string, answer: string, index: number) => void + onAnnotationRemoved?: (index: number) => void + switchSibling?: (siblingMessageId: string) => void + onFeatureBarClick?: (state: boolean) => void +} + +const mockFile: FileEntity = { + id: 'file-1', + name: 'test.png', + size: 123, + type: 'image/png', + progress: 100, + transferMethod: TransferMethod.local_file, + supportFileType: 'image', +} + +// Mock Chat 
component (complex with many dependencies) +// This is a pragmatic mock that tests the integration at DebugWithSingleModel level +jest.mock('@/app/components/base/chat/chat', () => { + return function MockChat({ + chatList, + isResponding, + onSend, + onRegenerate, + onStopResponding, + suggestedQuestions, + questionIcon, + answerIcon, + onAnnotationAdded, + onAnnotationEdited, + onAnnotationRemoved, + switchSibling, + onFeatureBarClick, + }: MockChatProps) { + const items = chatList || [] + const suggested = suggestedQuestions ?? [] + return ( +
+
+ {items.map((item: ChatItem) => ( +
+ {item.content} +
+ ))} +
+ {questionIcon &&
{questionIcon}
} + {answerIcon &&
{answerIcon}
} +